from .structs import ( infiniopHandle_t, infiniopTensorDescriptor_t, infiniopOperatorDescriptor_t, ) from ctypes import c_int32, c_void_p, c_size_t, POINTER, c_float class OpRegister: registry = [] @classmethod def operator(cls, op): cls.registry.append(op) return op @classmethod def register_lib(cls, lib): for op in cls.registry: op(lib) @OpRegister.operator def add_(lib): lib.infiniopCreateAddDescriptor.restype = c_int32 lib.infiniopCreateAddDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopGetAddWorkspaceSize.restype = c_int32 lib.infiniopGetAddWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopAdd.restype = c_int32 lib.infiniopAdd.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroyAddDescriptor.restype = c_int32 lib.infiniopDestroyAddDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def attention_(lib): lib.infiniopCreateAttentionDescriptor.restype = c_int32 lib.infiniopCreateAttentionDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, c_size_t, ] lib.infiniopGetAttentionWorkspaceSize.restype = c_int32 lib.infiniopGetAttentionWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopAttention.restype = c_int32 lib.infiniopAttention.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, c_void_p, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroyAttentionDescriptor.restype = c_int32 lib.infiniopDestroyAttentionDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def causal_softmax_(lib): lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32 lib.infiniopCreateCausalSoftmaxDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, ] lib.infiniopGetCausalSoftmaxWorkspaceSize.restype = c_int32 lib.infiniopGetCausalSoftmaxWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopCausalSoftmax.restype = c_int32 lib.infiniopCausalSoftmax.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, ] lib.infiniopDestroyCausalSoftmaxDescriptor.restype = c_int32 lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def clip_(lib): lib.infiniopCreateClipDescriptor.restype = c_int32 lib.infiniopCreateClipDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopGetClipWorkspaceSize.restype = c_int32 lib.infiniopGetClipWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopClip.restype = c_int32 lib.infiniopClip.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroyClipDescriptor.restype = c_int32 lib.infiniopDestroyClipDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def logsoftmax_(lib): lib.infiniopCreateLogSoftmaxDescriptor.restype = c_int32 lib.infiniopCreateLogSoftmaxDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopGetLogSoftmaxWorkspaceSize.restype = c_int32 lib.infiniopGetLogSoftmaxWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopLogSoftmax.restype = c_int32 lib.infiniopLogSoftmax.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroyLogSoftmaxDescriptor.restype = c_int32 lib.infiniopDestroyLogSoftmaxDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def conv_(lib): pass @OpRegister.operator def gemm_(lib): lib.infiniopCreateGemmDescriptor.restype = c_int32 lib.infiniopCreateGemmDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopGetGemmWorkspaceSize.restype = c_int32 lib.infiniopGetGemmWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopGemm.restype = c_int32 lib.infiniopGemm.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, c_float, c_float, c_void_p, ] lib.infiniopDestroyGemmDescriptor.restype = c_int32 lib.infiniopDestroyGemmDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def mul_(lib): lib.infiniopCreateMulDescriptor.restype = c_int32 lib.infiniopCreateMulDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopGetMulWorkspaceSize.restype = c_int32 lib.infiniopGetMulWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopMul.restype = c_int32 lib.infiniopMul.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroyMulDescriptor.restype = c_int32 lib.infiniopDestroyMulDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def random_sample_(lib): lib.infiniopCreateRandomSampleDescriptor.restype = c_int32 lib.infiniopCreateRandomSampleDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, ] lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32 lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopRandomSample.restype = c_int32 lib.infiniopRandomSample.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_size_t, c_void_p, c_float, c_float, c_int32, c_float, c_void_p, ] lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32 lib.infiniopDestroyRandomSampleDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def rearrange_(lib): lib.infiniopCreateRearrangeDescriptor.restype = c_int32 lib.infiniopCreateRearrangeDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopRearrange.restype = c_int32 lib.infiniopRearrange.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroyRearrangeDescriptor.restype = c_int32 lib.infiniopDestroyRearrangeDescriptor.argtypes = [infiniopOperatorDescriptor_t] @OpRegister.operator def relu_(lib): lib.infiniopCreateReluDescriptor.restype = c_int32 lib.infiniopCreateReluDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopRelu.restype = c_int32 lib.infiniopRelu.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroyReluDescriptor.restype = c_int32 lib.infiniopDestroyReluDescriptor.argtypes = [infiniopOperatorDescriptor_t] @OpRegister.operator def rms_norm_(lib): lib.infiniopCreateRMSNormDescriptor.restype = c_int32 lib.infiniopCreateRMSNormDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, c_float, ] lib.infiniopGetRMSNormWorkspaceSize.restype = c_int32 lib.infiniopGetRMSNormWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopRMSNorm.restype = c_int32 lib.infiniopRMSNorm.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroyRMSNormDescriptor.restype = c_int32 lib.infiniopDestroyRMSNormDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def rope_(lib): lib.infiniopCreateRoPEDescriptor.restype = c_int32 lib.infiniopCreateRoPEDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, c_int32, ] lib.infiniopGetRoPEWorkspaceSize.restype = c_int32 lib.infiniopGetRoPEWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopRoPE.restype = c_int32 lib.infiniopRoPE.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroyRoPEDescriptor.restype = c_int32 lib.infiniopDestroyRoPEDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def sub_(lib): lib.infiniopCreateSubDescriptor.restype = c_int32 lib.infiniopCreateSubDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopGetSubWorkspaceSize.restype = c_int32 lib.infiniopGetSubWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopSub.restype = c_int32 lib.infiniopSub.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroySubDescriptor.restype = c_int32 lib.infiniopDestroySubDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def swiglu_(lib): lib.infiniopCreateSwiGLUDescriptor.restype = c_int32 lib.infiniopCreateSwiGLUDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopGetSwiGLUWorkspaceSize.restype = c_int32 lib.infiniopGetSwiGLUWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopSwiGLU.restype = c_int32 lib.infiniopSwiGLU.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroySwiGLUDescriptor.restype = c_int32 lib.infiniopDestroySwiGLUDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def conv_(lib): lib.infiniopCreateConvDescriptor.restype = c_int32 lib.infiniopCreateConvDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, c_void_p, c_void_p, c_void_p, c_size_t, ] lib.infiniopGetConvWorkspaceSize.restype = c_int32 lib.infiniopGetConvWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopConv.restype = c_int32 lib.infiniopConv.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroyConvDescriptor.restype = c_int32 lib.infiniopDestroyConvDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def sigmoid_(lib): lib.infiniopCreateSigmoidDescriptor.restype = c_int32 lib.infiniopCreateSigmoidDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopGetSigmoidWorkspaceSize.restype = c_int32 lib.infiniopGetSigmoidWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopSigmoid.restype = c_int32 lib.infiniopSigmoid.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroySigmoidDescriptor.restype = c_int32 lib.infiniopDestroySigmoidDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def topksoftmax_(lib): lib.infiniopCreateTopksoftmaxDescriptor.restype = c_int32 lib.infiniopCreateTopksoftmaxDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, ] lib.infiniopGetTopksoftmaxWorkspaceSize.restype = c_int32 lib.infiniopGetTopksoftmaxWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopTopksoftmax.restype = c_int32 lib.infiniopTopksoftmax.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, c_size_t, c_int32, c_void_p, ] lib.infiniopDestroyTopksoftmaxDescriptor.restype = c_int32 lib.infiniopDestroyTopksoftmaxDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def topkrouter_(lib): lib.infiniopCreateTopkrouterDescriptor.restype = c_int32 lib.infiniopCreateTopkrouterDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t ] lib.infiniopGetTopkrouterWorkspaceSize.restype = c_int32 lib.infiniopGetTopkrouterWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopTopkrouter.restype = c_int32 lib.infiniopTopkrouter.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, c_void_p, c_float, c_size_t, c_void_p, ] lib.infiniopDestroyTopkrouterDescriptor.restype = c_int32 lib.infiniopDestroyTopkrouterDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def dequantize_(lib): lib.infiniopCreateDequantizeAWQDescriptor.restype = c_int32 lib.infiniopCreateDequantizeAWQDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopGetDequantizeAWQWorkspaceSize.restype = c_int32 lib.infiniopGetDequantizeAWQWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopDequantizeAWQ.restype = c_int32 lib.infiniopDequantizeAWQ.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroyDequantizeAWQDescriptor.restype = c_int32 lib.infiniopDestroyDequantizeAWQDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def softplus_(lib): lib.infiniopCreateSoftplusDescriptor.restype = c_int32 lib.infiniopCreateSoftplusDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopSoftplus.restype = c_int32 lib.infiniopSoftplus.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroySoftplusDescriptor.restype = c_int32 lib.infiniopDestroySoftplusDescriptor.argtypes = [infiniopOperatorDescriptor_t] @OpRegister.operator def zeros_(lib): lib.infiniopCreateZerosDescriptor.restype = c_int32 lib.infiniopCreateZerosDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopGetZerosWorkspaceSize.restype = c_int32 lib.infiniopGetZerosWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopZeros.restype = c_int32 lib.infiniopZeros.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroyZerosDescriptor.restype = c_int32 lib.infiniopDestroyZerosDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] @OpRegister.operator def ones_(lib): lib.infiniopCreateOnesDescriptor.restype = c_int32 lib.infiniopCreateOnesDescriptor.argtypes = [ infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, ] lib.infiniopGetOnesWorkspaceSize.restype = c_int32 lib.infiniopGetOnesWorkspaceSize.argtypes = [ infiniopOperatorDescriptor_t, POINTER(c_size_t), ] lib.infiniopOnes.restype = c_int32 lib.infiniopOnes.argtypes = [ infiniopOperatorDescriptor_t, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p, ] lib.infiniopDestroyOnesDescriptor.restype = c_int32 lib.infiniopDestroyOnesDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ]