import ctypes
from ctypes import c_size_t, c_uint, c_int, c_float, c_void_p, POINTER
import os


class DataType(ctypes.c_int):
    """``c_int``-based tag type whose class attributes are data-type codes.

    The codes are consecutive integers starting at 0; the trailing comment
    on each name gives its value. They presumably mirror the native
    library's dtype enum -- confirm against the C header before reordering.
    """

    (
        INFINI_DTYPE_INVALID,  # 0
        INFINI_DTYPE_BYTE,  # 1
        INFINI_DTYPE_BOOL,  # 2
        INFINI_DTYPE_I8,  # 3
        INFINI_DTYPE_I16,  # 4
        INFINI_DTYPE_I32,  # 5
        INFINI_DTYPE_I64,  # 6
        INFINI_DTYPE_U8,  # 7
        INFINI_DTYPE_U16,  # 8
        INFINI_DTYPE_U32,  # 9
        INFINI_DTYPE_U64,  # 10
        INFINI_DTYPE_F8,  # 11
        INFINI_DTYPE_F16,  # 12
        INFINI_DTYPE_F32,  # 13
        INFINI_DTYPE_F64,  # 14
        INFINI_DTYPE_C16,  # 15
        INFINI_DTYPE_C32,  # 16
        INFINI_DTYPE_C64,  # 17
        INFINI_DTYPE_C128,  # 18
        INFINI_DTYPE_BF16,  # 19
    ) = range(20)


class DeviceType(ctypes.c_int):
    """``c_int``-based tag type whose class attributes are device-type codes.

    Used in this module as the device argument type of ``createJiugeModel``.
    The numeric values presumably mirror the native library's device enum --
    confirm against the C header before changing them.
    """

    DEVICE_TYPE_CPU = 0
    DEVICE_TYPE_NVIDIA = 1
    DEVICE_TYPE_CAMBRICON = 2
    DEVICE_TYPE_ASCEND = 3
    DEVICE_TYPE_METAX = 4
    DEVICE_TYPE_MOORE = 5


class JiugeMeta(ctypes.Structure):
    """ctypes layout of the model metadata struct passed to ``createJiugeModel``.

    Field order and types define the binary layout and must match the native
    struct exactly. The per-field meanings noted below are inferred from the
    names only -- confirm against the C header.
    """

    _fields_ = [
        ("dt_logits", DataType),  # presumably the dtype of the output logits
        ("nlayer", c_size_t),  # presumably the number of layers
        ("d", c_size_t),  # presumably the hidden size
        ("nh", c_size_t),  # presumably the number of attention heads
        ("nkvh", c_size_t),  # presumably the number of key/value heads
        ("dh", c_size_t),  # presumably the per-head dimension
        ("di", c_size_t),  # presumably the FFN intermediate size
        ("dctx", c_size_t),  # presumably the maximum context length
        ("dvoc", c_size_t),  # presumably the vocabulary size
        ("epsilon", c_float),  # presumably the normalization epsilon
        ("theta", c_float),  # presumably the rotary-embedding base
        ("end_token", c_uint),  # presumably the end-of-sequence token id
    ]


# Define the JiugeWeights struct
class JiugeWeights(ctypes.Structure):
    """ctypes layout of the weights struct passed to ``createJiugeModel``.

    Field order and types define the binary layout and must match the native
    struct exactly. The three ``c_void_p`` fields are single raw pointers;
    the ``POINTER(c_void_p)`` fields are pointers to arrays of raw pointers
    (presumably one entry per layer, ``nlayer`` long -- confirm against the
    C header).
    """

    _fields_ = [
        ("nlayer", c_size_t),
        ("dt_norm", DataType),
        ("dt_mat", DataType),
        # Single tensors.
        ("input_embd", c_void_p),
        ("output_norm", c_void_p),
        ("output_embd", c_void_p),
        # Arrays of tensor pointers.
        ("attn_norm", POINTER(c_void_p)),
        ("attn_qkv", POINTER(c_void_p)),
        ("attn_qkv_b", POINTER(c_void_p)),
        ("attn_o", POINTER(c_void_p)),
        ("ffn_norm", POINTER(c_void_p)),
        ("ffn_gate_up", POINTER(c_void_p)),
        ("ffn_down", POINTER(c_void_p)),
    ]


class JiugeModel(ctypes.Structure):
    """Opaque native model handle; declares no fields and is used via pointer."""


class KVCache(ctypes.Structure):
    """Opaque native KV-cache handle; declares no fields and is used via pointer."""


def __open_library__():
    """Load ``libinfinicore_infer.so`` and declare the C prototypes used here.

    The library is looked up at ``$INFINI_ROOT/lib/libinfinicore_infer.so``.

    Returns:
        The ``ctypes.CDLL`` handle, with ``restype``/``argtypes`` set for
        ``createJiugeModel`` and ``inferBatch``, ``restype`` set for
        ``createKVCache`` and ``argtypes`` set for ``dropKVCache``.

    Raises:
        OSError: if the ``INFINI_ROOT`` environment variable is not set, or
            if ``ctypes.CDLL`` cannot load the shared library.
    """
    infini_root = os.environ.get("INFINI_ROOT")
    if infini_root is None:
        # Without this guard os.path.join(None, ...) fails with an opaque
        # TypeError that never mentions the missing variable.
        raise OSError(
            "INFINI_ROOT environment variable is not set; "
            "cannot locate lib/libinfinicore_infer.so"
        )
    lib_path = os.path.join(infini_root, "lib", "libinfinicore_infer.so")
    lib = ctypes.CDLL(lib_path)

    lib.createJiugeModel.restype = POINTER(JiugeModel)
    lib.createJiugeModel.argtypes = [
        POINTER(JiugeMeta),  # JiugeMeta const *
        POINTER(JiugeWeights),  # JiugeWeights const *
        DeviceType,  # DeviceType
        c_int,  # int ndev
        POINTER(c_int),  # int const *dev_ids
    ]

    # NOTE(review): createKVCache has no argtypes and dropKVCache has no
    # restype declared, so ctypes defaults apply (unchecked arguments, int
    # return) -- confirm the real signatures against the C header.
    lib.createKVCache.restype = POINTER(KVCache)
    lib.dropKVCache.argtypes = [POINTER(JiugeModel), POINTER(KVCache)]

    lib.inferBatch.restype = None
    lib.inferBatch.argtypes = [
        POINTER(JiugeModel),  # struct JiugeModel const *
        POINTER(c_uint),  # unsigned int const *tokens
        c_uint,  # unsigned int ntok
        POINTER(c_uint),  # unsigned int const *req_lens
        c_uint,  # unsigned int nreq
        POINTER(c_uint),  # unsigned int const *req_pos
        POINTER(POINTER(KVCache)),  # struct KVCache **kv_caches
        POINTER(c_uint),  # unsigned int *output
        c_float,  # float temperature
        c_uint,  # unsigned int topk
        c_float,  # float topp
    ]

    return lib


# Load the shared library once, at import time; importing this module fails
# if __open_library__() raises.
LIB = __open_library__()

# Module-level aliases for the native entry points configured above.
create_jiuge_model = LIB.createJiugeModel
create_kv_cache = LIB.createKVCache
drop_kv_cache = LIB.dropKVCache
infer_batch = LIB.inferBatch