"""ctypes signature declarations for the infiniop operator entry points.

Each function decorated with ``@OpRegister.operator`` assigns ``restype`` and
``argtypes`` on the C functions of one operator. ``OpRegister.register_lib``
runs every registered function against a loaded library object.
"""

from .structs import (
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    infiniopOperatorDescriptor_t,
)

from ctypes import c_int32, c_void_p, c_size_t, POINTER, c_float


class OpRegister:
    """Collects operator binding functions and applies them to a library."""

    # Shared on the class on purpose: decorators append to it at import time.
    registry = []

    @classmethod
    def operator(cls, op):
        """Append ``op`` to the registry and return it (usable as a decorator)."""
        cls.registry.append(op)
        return op

    @classmethod
    def register_lib(cls, lib):
        """Call every registered binding function with ``lib``, in registration order."""
        for op in cls.registry:
            op(lib)


@OpRegister.operator
def add_(lib):
    """Declare the ctypes signatures of the Add operator functions on ``lib``."""
    lib.infiniopCreateAddDescriptor.restype = c_int32
    lib.infiniopCreateAddDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetAddWorkspaceSize.restype = c_int32
    lib.infiniopGetAddWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopAdd.restype = c_int32
    lib.infiniopAdd.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyAddDescriptor.restype = c_int32
    lib.infiniopDestroyAddDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def equal_(lib):
    """Declare the ctypes signatures of the Equal operator functions on ``lib``."""
    # =========================================================
    # 1. Register the Create function
    # C signature: (handle, &desc, output_desc, input_a_desc, input_b_desc)
    # =========================================================
    lib.infiniopCreateEqualDescriptor.restype = c_int32
    lib.infiniopCreateEqualDescriptor.argtypes = [
        infiniopHandle_t,  # handle
        POINTER(infiniopOperatorDescriptor_t),  # desc_ptr (output)
        infiniopTensorDescriptor_t,  # output (c)
        infiniopTensorDescriptor_t,  # input_a
        infiniopTensorDescriptor_t,  # input_b
    ]

    # =========================================================
    # 2. Register the GetWorkspaceSize function
    # C signature: (desc, &size)
    # =========================================================
    lib.infiniopGetEqualWorkspaceSize.restype = c_int32
    lib.infiniopGetEqualWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    # =========================================================
    # 3. Register the Execute (compute) function
    # C signature: (desc, workspace, size, output_data, input_a_data,
    #               input_b_data, stream)
    # =========================================================
    lib.infiniopEqual.restype = c_int32
    lib.infiniopEqual.argtypes = [
        infiniopOperatorDescriptor_t,  # desc
        c_void_p,  # workspace ptr
        c_size_t,  # workspace size
        c_void_p,  # output data ptr
        c_void_p,  # input a data ptr
        c_void_p,  # input b data ptr
        c_void_p,  # stream
    ]

    # =========================================================
    # 4. Register the Destroy function
    # C signature: (desc)
    # =========================================================
    lib.infiniopDestroyEqualDescriptor.restype = c_int32
    lib.infiniopDestroyEqualDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def attention_(lib):
    """Declare the ctypes signatures of the Attention operator functions on ``lib``."""
    lib.infiniopCreateAttentionDescriptor.restype = c_int32
    lib.infiniopCreateAttentionDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_size_t,
    ]

    lib.infiniopGetAttentionWorkspaceSize.restype = c_int32
    lib.infiniopGetAttentionWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopAttention.restype = c_int32
    lib.infiniopAttention.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyAttentionDescriptor.restype = c_int32
    lib.infiniopDestroyAttentionDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def causal_softmax_(lib):
    """Declare the ctypes signatures of the CausalSoftmax operator functions on ``lib``."""
    lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32
    lib.infiniopCreateCausalSoftmaxDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetCausalSoftmaxWorkspaceSize.restype = c_int32
    lib.infiniopGetCausalSoftmaxWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopCausalSoftmax.restype = c_int32
    lib.infiniopCausalSoftmax.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyCausalSoftmaxDescriptor.restype = c_int32
    lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def clip_(lib):
    """Declare the ctypes signatures of the Clip operator functions on ``lib``."""
    lib.infiniopCreateClipDescriptor.restype = c_int32
    lib.infiniopCreateClipDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetClipWorkspaceSize.restype = c_int32
    lib.infiniopGetClipWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopClip.restype = c_int32
    lib.infiniopClip.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyClipDescriptor.restype = c_int32
    lib.infiniopDestroyClipDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def cross_entropy_(lib):
    """Declare the ctypes signatures of the CrossEntropy operator functions on ``lib``."""
    lib.infiniopCreateCrossEntropyDescriptor.restype = c_int32
    lib.infiniopCreateCrossEntropyDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetCrossEntropyWorkspaceSize.restype = c_int32
    lib.infiniopGetCrossEntropyWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopCrossEntropy.restype = c_int32
    lib.infiniopCrossEntropy.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyCrossEntropyDescriptor.restype = c_int32
    lib.infiniopDestroyCrossEntropyDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def logsoftmax_(lib):
    """Declare the ctypes signatures of the LogSoftmax operator functions on ``lib``."""
    lib.infiniopCreateLogSoftmaxDescriptor.restype = c_int32
    lib.infiniopCreateLogSoftmaxDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetLogSoftmaxWorkspaceSize.restype = c_int32
    lib.infiniopGetLogSoftmaxWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopLogSoftmax.restype = c_int32
    lib.infiniopLogSoftmax.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyLogSoftmaxDescriptor.restype = c_int32
    lib.infiniopDestroyLogSoftmaxDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def gemm_(lib):
    """Declare the ctypes signatures of the Gemm operator functions on ``lib``."""
    lib.infiniopCreateGemmDescriptor.restype = c_int32
    lib.infiniopCreateGemmDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetGemmWorkspaceSize.restype = c_int32
    lib.infiniopGetGemmWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopGemm.restype = c_int32
    lib.infiniopGemm.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_float,
        c_float,
        c_void_p,
    ]

    lib.infiniopDestroyGemmDescriptor.restype = c_int32
    lib.infiniopDestroyGemmDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def mul_(lib):
    """Declare the ctypes signatures of the Mul operator functions on ``lib``."""
    lib.infiniopCreateMulDescriptor.restype = c_int32
    lib.infiniopCreateMulDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetMulWorkspaceSize.restype = c_int32
    lib.infiniopGetMulWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopMul.restype = c_int32
    lib.infiniopMul.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyMulDescriptor.restype = c_int32
    lib.infiniopDestroyMulDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def random_sample_(lib):
    """Declare the ctypes signatures of the RandomSample operator functions on ``lib``."""
    lib.infiniopCreateRandomSampleDescriptor.restype = c_int32
    lib.infiniopCreateRandomSampleDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32
    lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopRandomSample.restype = c_int32
    lib.infiniopRandomSample.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_size_t,
        c_void_p,
        c_float,
        c_float,
        c_int32,
        c_float,
        c_void_p,
    ]

    lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32
    lib.infiniopDestroyRandomSampleDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def rearrange_(lib):
    """Declare the ctypes signatures of the Rearrange operator functions on ``lib``."""
    lib.infiniopCreateRearrangeDescriptor.restype = c_int32
    lib.infiniopCreateRearrangeDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopRearrange.restype = c_int32
    lib.infiniopRearrange.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyRearrangeDescriptor.restype = c_int32
    lib.infiniopDestroyRearrangeDescriptor.argtypes = [infiniopOperatorDescriptor_t]


@OpRegister.operator
def relu_(lib):
    """Declare the ctypes signatures of the Relu operator functions on ``lib``."""
    lib.infiniopCreateReluDescriptor.restype = c_int32
    lib.infiniopCreateReluDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    # NOTE(review): no GetReluWorkspaceSize signature is declared here even
    # though infiniopRelu takes a c_size_t argument - confirm this is intended.
    lib.infiniopRelu.restype = c_int32
    lib.infiniopRelu.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyReluDescriptor.restype = c_int32
    lib.infiniopDestroyReluDescriptor.argtypes = [infiniopOperatorDescriptor_t]


@OpRegister.operator
def rms_norm_(lib):
    """Declare the ctypes signatures of the RMSNorm operator functions on ``lib``."""
    lib.infiniopCreateRMSNormDescriptor.restype = c_int32
    lib.infiniopCreateRMSNormDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_float,
    ]

    lib.infiniopGetRMSNormWorkspaceSize.restype = c_int32
    lib.infiniopGetRMSNormWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopRMSNorm.restype = c_int32
    lib.infiniopRMSNorm.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyRMSNormDescriptor.restype = c_int32
    lib.infiniopDestroyRMSNormDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def add_rms_norm_(lib):
    """Declare the ctypes signatures of the AddRMSNorm operator functions on ``lib``."""
    lib.infiniopCreateAddRMSNormDescriptor.restype = c_int32
    lib.infiniopCreateAddRMSNormDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_float,
    ]

    lib.infiniopGetAddRMSNormWorkspaceSize.restype = c_int32
    lib.infiniopGetAddRMSNormWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopAddRMSNorm.restype = c_int32
    lib.infiniopAddRMSNorm.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyAddRMSNormDescriptor.restype = c_int32
    lib.infiniopDestroyAddRMSNormDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def rope_(lib):
    """Declare the ctypes signatures of the RoPE operator functions on ``lib``."""
    lib.infiniopCreateRoPEDescriptor.restype = c_int32
    lib.infiniopCreateRoPEDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_int32,
    ]

    lib.infiniopGetRoPEWorkspaceSize.restype = c_int32
    lib.infiniopGetRoPEWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopRoPE.restype = c_int32
    lib.infiniopRoPE.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyRoPEDescriptor.restype = c_int32
    lib.infiniopDestroyRoPEDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def sub_(lib):
    """Declare the ctypes signatures of the Sub operator functions on ``lib``."""
    lib.infiniopCreateSubDescriptor.restype = c_int32
    lib.infiniopCreateSubDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetSubWorkspaceSize.restype = c_int32
    lib.infiniopGetSubWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopSub.restype = c_int32
    lib.infiniopSub.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroySubDescriptor.restype = c_int32
    lib.infiniopDestroySubDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def softmax_(lib):
    """Declare the ctypes signatures of the Softmax operator functions on ``lib``."""
    lib.infiniopCreateSoftmaxDescriptor.restype = c_int32
    lib.infiniopCreateSoftmaxDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_int32,
    ]

    lib.infiniopGetSoftmaxWorkspaceSize.restype = c_int32
    lib.infiniopGetSoftmaxWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopSoftmax.restype = c_int32
    lib.infiniopSoftmax.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroySoftmaxDescriptor.restype = c_int32
    lib.infiniopDestroySoftmaxDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def swiglu_(lib):
    """Declare the ctypes signatures of the SwiGLU operator functions on ``lib``."""
    lib.infiniopCreateSwiGLUDescriptor.restype = c_int32
    lib.infiniopCreateSwiGLUDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetSwiGLUWorkspaceSize.restype = c_int32
    lib.infiniopGetSwiGLUWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopSwiGLU.restype = c_int32
    lib.infiniopSwiGLU.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroySwiGLUDescriptor.restype = c_int32
    lib.infiniopDestroySwiGLUDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def conv_(lib):
    """Declare the ctypes signatures of the Conv operator functions on ``lib``."""
    lib.infiniopCreateConvDescriptor.restype = c_int32
    lib.infiniopCreateConvDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_size_t,
    ]

    lib.infiniopGetConvWorkspaceSize.restype = c_int32
    lib.infiniopGetConvWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopConv.restype = c_int32
    lib.infiniopConv.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyConvDescriptor.restype = c_int32
    lib.infiniopDestroyConvDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def sigmoid_(lib):
    """Declare the ctypes signatures of the Sigmoid operator functions on ``lib``."""
    lib.infiniopCreateSigmoidDescriptor.restype = c_int32
    lib.infiniopCreateSigmoidDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetSigmoidWorkspaceSize.restype = c_int32
    lib.infiniopGetSigmoidWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopSigmoid.restype = c_int32
    lib.infiniopSigmoid.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroySigmoidDescriptor.restype = c_int32
    lib.infiniopDestroySigmoidDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def topksoftmax_(lib):
    """Declare the ctypes signatures of the Topksoftmax operator functions on ``lib``."""
    lib.infiniopCreateTopksoftmaxDescriptor.restype = c_int32
    lib.infiniopCreateTopksoftmaxDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetTopksoftmaxWorkspaceSize.restype = c_int32
    lib.infiniopGetTopksoftmaxWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopTopksoftmax.restype = c_int32
    lib.infiniopTopksoftmax.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_size_t,
        c_int32,
        c_void_p,
    ]

    lib.infiniopDestroyTopksoftmaxDescriptor.restype = c_int32
    lib.infiniopDestroyTopksoftmaxDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def topkrouter_(lib):
    """Declare the ctypes signatures of the Topkrouter operator functions on ``lib``."""
    lib.infiniopCreateTopkrouterDescriptor.restype = c_int32
    lib.infiniopCreateTopkrouterDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetTopkrouterWorkspaceSize.restype = c_int32
    lib.infiniopGetTopkrouterWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopTopkrouter.restype = c_int32
    lib.infiniopTopkrouter.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_float,
        c_size_t,
        c_void_p,
    ]

    lib.infiniopDestroyTopkrouterDescriptor.restype = c_int32
    lib.infiniopDestroyTopkrouterDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def dequantize_(lib):
    """Declare the ctypes signatures of the DequantizeAWQ operator functions on ``lib``."""
    lib.infiniopCreateDequantizeAWQDescriptor.restype = c_int32
    lib.infiniopCreateDequantizeAWQDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetDequantizeAWQWorkspaceSize.restype = c_int32
    lib.infiniopGetDequantizeAWQWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopDequantizeAWQ.restype = c_int32
    lib.infiniopDequantizeAWQ.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyDequantizeAWQDescriptor.restype = c_int32
    lib.infiniopDestroyDequantizeAWQDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def per_channel_quant_int8_(lib):
    """Declare the ctypes signatures of the PerChannelQuantI8 operator functions on ``lib``."""
    lib.infiniopCreatePerChannelQuantI8Descriptor.restype = c_int32
    lib.infiniopCreatePerChannelQuantI8Descriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetPerChannelQuantI8WorkspaceSize.restype = c_int32
    lib.infiniopGetPerChannelQuantI8WorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopPerChannelQuantI8.restype = c_int32
    lib.infiniopPerChannelQuantI8.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyPerChannelQuantI8Descriptor.restype = c_int32
    lib.infiniopDestroyPerChannelQuantI8Descriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def softplus_(lib):
    """Declare the ctypes signatures of the Softplus operator functions on ``lib``."""
    lib.infiniopCreateSoftplusDescriptor.restype = c_int32
    lib.infiniopCreateSoftplusDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    # NOTE(review): no GetSoftplusWorkspaceSize signature is declared here even
    # though infiniopSoftplus takes a c_size_t argument - confirm this is intended.
    lib.infiniopSoftplus.restype = c_int32
    lib.infiniopSoftplus.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroySoftplusDescriptor.restype = c_int32
    lib.infiniopDestroySoftplusDescriptor.argtypes = [infiniopOperatorDescriptor_t]


@OpRegister.operator
def zeros_(lib):
    """Declare the ctypes signatures of the Zeros operator functions on ``lib``."""
    lib.infiniopCreateZerosDescriptor.restype = c_int32
    lib.infiniopCreateZerosDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetZerosWorkspaceSize.restype = c_int32
    lib.infiniopGetZerosWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopZeros.restype = c_int32
    lib.infiniopZeros.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyZerosDescriptor.restype = c_int32
    lib.infiniopDestroyZerosDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def ones_(lib):
    """Declare the ctypes signatures of the Ones operator functions on ``lib``."""
    lib.infiniopCreateOnesDescriptor.restype = c_int32
    lib.infiniopCreateOnesDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetOnesWorkspaceSize.restype = c_int32
    lib.infiniopGetOnesWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopOnes.restype = c_int32
    lib.infiniopOnes.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyOnesDescriptor.restype = c_int32
    lib.infiniopDestroyOnesDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def gelu_(lib):
    """Declare the ctypes signatures of the Gelu operator functions on ``lib``."""
    lib.infiniopCreateGeluDescriptor.restype = c_int32
    lib.infiniopCreateGeluDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetGeluWorkspaceSize.restype = c_int32
    lib.infiniopGetGeluWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopGelu.restype = c_int32
    lib.infiniopGelu.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyGeluDescriptor.restype = c_int32
    lib.infiniopDestroyGeluDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def silu_(lib):
    """Declare the ctypes signatures of the Silu operator functions on ``lib``."""
    lib.infiniopCreateSiluDescriptor.restype = c_int32
    lib.infiniopCreateSiluDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetSiluWorkspaceSize.restype = c_int32
    lib.infiniopGetSiluWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopSilu.restype = c_int32
    lib.infiniopSilu.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroySiluDescriptor.restype = c_int32
    lib.infiniopDestroySiluDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def hardtanh_(lib):
    """Declare the ctypes signatures of the HardTanh operator functions on ``lib``."""
    # 1. Create Descriptor - note the two extra c_float parameters
    lib.infiniopCreateHardTanhDescriptor.restype = c_int32
    lib.infiniopCreateHardTanhDescriptor.argtypes = [
        infiniopHandle_t,  # handle
        POINTER(infiniopOperatorDescriptor_t),  # desc_ptr
        infiniopTensorDescriptor_t,  # output
        infiniopTensorDescriptor_t,  # input
        c_float,  # min_val
        c_float,  # max_val
    ]

    # 2. Get Workspace Size
    lib.infiniopGetHardTanhWorkspaceSize.restype = c_int32
    lib.infiniopGetHardTanhWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,  # desc
        POINTER(c_size_t),  # size
    ]

    # 3. Execute Operator
    lib.infiniopHardTanh.restype = c_int32
    lib.infiniopHardTanh.argtypes = [
        infiniopOperatorDescriptor_t,  # desc
        c_void_p,  # workspace
        c_size_t,  # workspace_size
        c_void_p,  # output
        c_void_p,  # input
        c_void_p,  # stream
    ]

    # 4. Destroy Descriptor
    lib.infiniopDestroyHardTanhDescriptor.restype = c_int32
    lib.infiniopDestroyHardTanhDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,  # desc
    ]


@OpRegister.operator
def hardswish_(lib):
    """Declare the ctypes signatures of the HardSwish operator functions on ``lib``."""
    lib.infiniopCreateHardSwishDescriptor.restype = c_int32
    lib.infiniopCreateHardSwishDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetHardSwishWorkspaceSize.restype = c_int32
    lib.infiniopGetHardSwishWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopHardSwish.restype = c_int32
    lib.infiniopHardSwish.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyHardSwishDescriptor.restype = c_int32
    lib.infiniopDestroyHardSwishDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def avg_pool1d_(lib):
    """Declare the ctypes signatures of the AvgPool1d operator functions on ``lib``."""
    # 1. Create function
    # C signature: (handle, *desc, y, x, kernel_size, stride, padding)
    lib.infiniopCreateAvgPool1dDescriptor.restype = c_int32
    lib.infiniopCreateAvgPool1dDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,  # y_desc (Output)
        infiniopTensorDescriptor_t,  # x_desc (Input)
        c_size_t,  # kernel_size
        c_size_t,  # stride
        c_size_t,  # padding
    ]

    # 2. GetWorkspaceSize function
    lib.infiniopGetAvgPool1dWorkspaceSize.restype = c_int32
    lib.infiniopGetAvgPool1dWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    # 3. Execute function
    lib.infiniopAvgPool1d.restype = c_int32
    lib.infiniopAvgPool1d.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,  # workspace
        c_size_t,  # workspace_size
        c_void_p,  # y (output pointer)
        c_void_p,  # x (input pointer)
        c_void_p,  # stream
    ]

    # 4. Destroy function
    lib.infiniopDestroyAvgPool1dDescriptor.restype = c_int32
    lib.infiniopDestroyAvgPool1dDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def layer_norm_(lib):
    """Declare the ctypes signatures of the LayerNorm operator functions on ``lib``."""
    lib.infiniopCreateLayerNormDescriptor.restype = c_int32
    lib.infiniopCreateLayerNormDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_float,
    ]

    lib.infiniopGetLayerNormWorkspaceSize.restype = c_int32
    lib.infiniopGetLayerNormWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopLayerNorm.restype = c_int32
    lib.infiniopLayerNorm.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyLayerNormDescriptor.restype = c_int32
    lib.infiniopDestroyLayerNormDescriptor.argtypes = [infiniopOperatorDescriptor_t]


@OpRegister.operator
def lp_norm_(lib):
    """Declare the ctypes signatures of the LPNorm operator functions on ``lib``."""
    lib.infiniopCreateLPNormDescriptor.restype = c_int32
    lib.infiniopCreateLPNormDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_int32,
        c_int32,
        c_float,
    ]

    lib.infiniopGetLPNormWorkspaceSize.restype = c_int32
    lib.infiniopGetLPNormWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopLPNorm.restype = c_int32
    lib.infiniopLPNorm.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyLPNormDescriptor.restype = c_int32
    lib.infiniopDestroyLPNormDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def tanh_(lib):
    """Declare the ctypes signatures of the Tanh operator functions on ``lib``."""
    lib.infiniopCreateTanhDescriptor.restype = c_int32
    lib.infiniopCreateTanhDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetTanhWorkspaceSize.restype = c_int32
    lib.infiniopGetTanhWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopTanh.restype = c_int32
    lib.infiniopTanh.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyTanhDescriptor.restype = c_int32
    lib.infiniopDestroyTanhDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def scaled_mm_int8_(lib):
    """Declare the ctypes signatures of the I8Gemm operator functions on ``lib``."""
    lib.infiniopCreateI8GemmDescriptor.restype = c_int32
    lib.infiniopCreateI8GemmDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetI8GemmWorkspaceSize.restype = c_int32
    lib.infiniopGetI8GemmWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopI8Gemm.restype = c_int32
    lib.infiniopI8Gemm.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyI8GemmDescriptor.restype = c_int32
    lib.infiniopDestroyI8GemmDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def kv_caching_(lib):
    """Declare the ctypes signatures of the KVCaching operator functions on ``lib``."""
    lib.infiniopCreateKVCachingDescriptor.restype = c_int32
    lib.infiniopCreateKVCachingDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetKVCachingWorkspaceSize.restype = c_int32
    lib.infiniopGetKVCachingWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopKVCaching.restype = c_int32
    lib.infiniopKVCaching.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyKVCachingDescriptor.restype = c_int32
    lib.infiniopDestroyKVCachingDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def paged_attention_(lib):
    """Declare the ctypes signatures of the PagedAttention operator functions on ``lib``."""
    lib.infiniopCreatePagedAttentionDescriptor.restype = c_int32
    lib.infiniopCreatePagedAttentionDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_void_p,
        c_float,
    ]

    lib.infiniopGetPagedAttentionWorkspaceSize.restype = c_int32
    lib.infiniopGetPagedAttentionWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopPagedAttention.restype = c_int32
    lib.infiniopPagedAttention.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyPagedAttentionDescriptor.restype = c_int32
    lib.infiniopDestroyPagedAttentionDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def paged_caching_(lib):
    """Declare the ctypes signatures of the PagedCaching operator functions on ``lib``."""
    lib.infiniopCreatePagedCachingDescriptor.restype = c_int32
    lib.infiniopCreatePagedCachingDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,  # k_cache_desc
        infiniopTensorDescriptor_t,  # v_cache_desc
        infiniopTensorDescriptor_t,  # k_desc
        infiniopTensorDescriptor_t,  # v_desc
        infiniopTensorDescriptor_t,  # slot_mapping_desc
    ]

    # infiniopGetPagedCachingWorkspaceSize
    lib.infiniopGetPagedCachingWorkspaceSize.restype = c_int32
    lib.infiniopGetPagedCachingWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    # infiniopPagedCaching
    lib.infiniopPagedCaching.restype = c_int32
    lib.infiniopPagedCaching.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,  # workspace
        c_size_t,  # workspace_size
        c_void_p,  # k_cache
        c_void_p,  # v_cache
        c_void_p,  # k
        c_void_p,  # v
        c_void_p,  # slot_mapping
        c_void_p,  # stream
    ]

    # infiniopDestroyPagedCachingDescriptor
    lib.infiniopDestroyPagedCachingDescriptor.restype = c_int32
    lib.infiniopDestroyPagedCachingDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def paged_attention_prefill_(lib):
    """Declare the ctypes signatures of the PagedAttentionPrefill operator functions on ``lib``."""
    lib.infiniopCreatePagedAttentionPrefillDescriptor.restype = c_int32
    lib.infiniopCreatePagedAttentionPrefillDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_float,
    ]

    lib.infiniopGetPagedAttentionPrefillWorkspaceSize.restype = c_int32
    lib.infiniopGetPagedAttentionPrefillWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopPagedAttentionPrefill.restype = c_int32
    lib.infiniopPagedAttentionPrefill.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyPagedAttentionPrefillDescriptor.restype = c_int32
    lib.infiniopDestroyPagedAttentionPrefillDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def silu_and_mul(lib):
    """Declare the ctypes signatures of the SiluAndMul operator functions on ``lib``."""
    lib.infiniopCreateSiluAndMulDescriptor.restype = c_int32
    lib.infiniopCreateSiluAndMulDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetSiluAndMulWorkspaceSize.restype = c_int32
    lib.infiniopGetSiluAndMulWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopSiluAndMul.restype = c_int32
    lib.infiniopSiluAndMul.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroySiluAndMulDescriptor.restype = c_int32
    lib.infiniopDestroySiluAndMulDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]