Unverified commit 3b5afffe, authored by Haojie Wang, committed by GitHub
Browse files

Merge pull request #842 from gongchensu/Issue/791

Issue/791 增加add_rms_norm融合算子
parents 2d9d5c30 7712471f
......@@ -383,6 +383,43 @@ def rms_norm_(lib):
]
@OpRegister.operator
def add_rms_norm_(lib):
    """Declare the ctypes signatures of the AddRMSNorm entry points on ``lib``.

    For each of the create-descriptor, get-workspace-size, compute and
    destroy-descriptor functions, ``restype`` is set to ``c_int32`` and
    ``argtypes`` is set to the parameter list given below.

    Args:
        lib: Library object exposing the ``infiniop*AddRMSNorm*`` functions
            as attributes.
    """
    signatures = (
        (
            "infiniopCreateAddRMSNormDescriptor",
            [
                infiniopHandle_t,
                POINTER(infiniopOperatorDescriptor_t),
                # Four tensor descriptors followed by one float.
                # NOTE(review): the float is presumably epsilon, and the
                # role/order of the tensors should be confirmed against the
                # C header.
                *([infiniopTensorDescriptor_t] * 4),
                c_float,
            ],
        ),
        (
            "infiniopGetAddRMSNormWorkspaceSize",
            [infiniopOperatorDescriptor_t, POINTER(c_size_t)],
        ),
        (
            "infiniopAddRMSNorm",
            [
                infiniopOperatorDescriptor_t,
                # NOTE(review): presumably the workspace pointer and its
                # size in bytes -- confirm against the C header.
                c_void_p,
                c_size_t,
                # Five trailing opaque pointers.
                *([c_void_p] * 5),
            ],
        ),
        (
            "infiniopDestroyAddRMSNormDescriptor",
            [infiniopOperatorDescriptor_t],
        ),
    )

    for symbol, params in signatures:
        fn = getattr(lib, symbol)
        fn.restype = c_int32
        fn.argtypes = params
@OpRegister.operator
def rope_(lib):
lib.infiniopCreateRoPEDescriptor.restype = c_int32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment