Unverified Commit 3998658b authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #41 from InfiniTensor/model_scripts_modualization

Model scripts modularization
parent 8bd0f91c
...@@ -4,17 +4,11 @@ from typing import List, Sequence ...@@ -4,17 +4,11 @@ from typing import List, Sequence
from tqdm import tqdm from tqdm import tqdm
from libinfinicore_infer import ( from libinfinicore_infer import (
DeepSeekV3Model,
DeepSeekV3MetaCStruct, DeepSeekV3MetaCStruct,
DeepSeekV3CacheCStruct, DeepSeekV3CacheCStruct,
DataType, DataType,
DeviceType, DeviceType,
create_deepseek_v3_model,
create_deepseek_v3_weights,
create_deepseek_v3_weight_loader,
destroy_deepseek_v3_model,
create_deepseek_v3_cache,
drop_deepseek_v3_cache,
infer_batch_deepseek_v3,
) )
from infer_task import InferTask, KVCache from infer_task import InferTask, KVCache
...@@ -306,9 +300,12 @@ def load_deepseek_weights( ...@@ -306,9 +300,12 @@ def load_deepseek_weights(
model_path: str, model_path: str,
ndev: int, ndev: int,
): ):
weight_loader = create_deepseek_v3_weight_loader() model_instance = DeepSeekV3Model()
weight_loader = model_instance.create_weight_loader()
names = DeepseekR1WeightsNaming() names = DeepseekR1WeightsNaming()
input_embd = load_specific_tensor(model_path, names.input_embd()).to(meta.torch_dtype_logits) input_embd = load_specific_tensor(model_path, names.input_embd()).to(
meta.torch_dtype_logits
)
weight_loader.contents.load_input_embd(weights, input_embd.data_ptr()) weight_loader.contents.load_input_embd(weights, input_embd.data_ptr())
del input_embd del input_embd
...@@ -590,7 +587,9 @@ class DeepSeekV3ForCauslLM: ...@@ -590,7 +587,9 @@ class DeepSeekV3ForCauslLM:
print(model_dir_path) print(model_dir_path)
if "deepseek_v3" == config["model_type"]: if "deepseek_v3" == config["model_type"]:
self.meta = DeepSeekV3Meta(config, max_tokens=max_tokens, dtype=torch.float16) self.meta = DeepSeekV3Meta(
config, max_tokens=max_tokens, dtype=torch.float16
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path) self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
else: else:
raise ValueError("Unsupported model architecture") raise ValueError("Unsupported model architecture")
...@@ -598,16 +597,18 @@ class DeepSeekV3ForCauslLM: ...@@ -598,16 +597,18 @@ class DeepSeekV3ForCauslLM:
print(f"Creating model on {ndev} devices...") print(f"Creating model on {ndev} devices...")
load_start_time = time.time() load_start_time = time.time()
dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
weights = create_deepseek_v3_weights(
self.meta, self.model_instance = DeepSeekV3Model()
weights = self.model_instance.create_weights(
byref(self.meta),
device, device,
ndev, ndev,
dev_ids, dev_ids,
) )
# Load weights from host # Load weights from host
# load_deepseek_weights(self.meta, weights, model_dir_path, ndev) load_deepseek_weights(self.meta, weights, model_dir_path, ndev)
# Create model instance # Create model instance
self.model_instance = create_deepseek_v3_model( self.model_ptr = self.model_instance.create_model(
byref(self.meta), byref(self.meta),
weights, weights,
) )
...@@ -618,16 +619,16 @@ class DeepSeekV3ForCauslLM: ...@@ -618,16 +619,16 @@ class DeepSeekV3ForCauslLM:
return self.meta.dctx return self.meta.dctx
def create_kv_cache(self): def create_kv_cache(self):
return create_deepseek_v3_cache(self.model_instance) return self.model_instance.create_cache(self.model_ptr)
def drop_kv_cache(self, kv_cache): def drop_kv_cache(self, kv_cache):
drop_deepseek_v3_cache(self.model_instance, kv_cache) self.model_instance.drop_cache(self.model_ptr, kv_cache)
def batch_infer_one_round(self, tasks: List[InferTask]): def batch_infer_one_round(self, tasks: List[InferTask]):
output = (c_uint * len(tasks))() output = (c_uint * len(tasks))()
batch_inputs = DeepSeekV3BatchedTask(tasks) batch_inputs = DeepSeekV3BatchedTask(tasks)
infer_batch_deepseek_v3( self.model_instance.infer_batch(
self.model_instance, self.model_ptr,
*(batch_inputs.input_args()), *(batch_inputs.input_args()),
output, output,
) )
...@@ -736,7 +737,7 @@ class DeepSeekV3ForCauslLM: ...@@ -736,7 +737,7 @@ class DeepSeekV3ForCauslLM:
# return math.exp(nll / total_len) # return math.exp(nll / total_len)
def destroy_model_instance(self): def destroy_model_instance(self):
destroy_deepseek_v3_model(self.model_instance) self.model_instance.destroy_model(self.model_ptr)
print("Model destroyed") print("Model destroyed")
......
from typing import List, Sequence from typing import List, Sequence
import math
import os
from pathlib import Path
import safetensors
import sys
import time
import json
import torch
import transformers
from sympy import true
from libinfinicore_infer import ( from libinfinicore_infer import (
JiugeModel,
JiugeMetaCStruct, JiugeMetaCStruct,
JiugeWeightsCStruct, JiugeWeightsCStruct,
KVCacheCStruct,
DataType, DataType,
DeviceType, DeviceType,
create_jiuge_model, KVCacheCStruct,
destroy_jiuge_model,
create_kv_cache,
drop_kv_cache,
infer_batch_jiuge,
forward_batch_jiuge,
) )
from infer_task import InferTask, KVCache from infer_task import InferTask, KVCache
from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import os
from pathlib import Path
import safetensors
import sys
import time
import json
import math
import torch
import transformers
torch.set_default_device("cpu") torch.set_default_device("cpu")
...@@ -419,6 +413,9 @@ class JiugeForCauslLM: ...@@ -419,6 +413,9 @@ class JiugeForCauslLM:
transpose_weight = ( transpose_weight = (
device != DeviceType.DEVICE_TYPE_ASCEND device != DeviceType.DEVICE_TYPE_ASCEND
) # y = xW is faster than y=xW^T on Ascend ) # y = xW is faster than y=xW^T on Ascend
self.jiuge_model = JiugeModel()
if "llama" == config["model_type"]: if "llama" == config["model_type"]:
model = ( model = (
transformers.LlamaForCausalLM.from_pretrained(model_dir_path) transformers.LlamaForCausalLM.from_pretrained(model_dir_path)
...@@ -509,7 +506,8 @@ class JiugeForCauslLM: ...@@ -509,7 +506,8 @@ class JiugeForCauslLM:
self.dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) self.dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
self.ndev = ndev self.ndev = ndev
self.device = device self.device = device
self.model_instance = create_jiuge_model(
self.model_instance = self.jiuge_model.create_model(
byref(self.meta), byref(self.meta),
byref(self.weights), byref(self.weights),
device, device,
...@@ -523,7 +521,7 @@ class JiugeForCauslLM: ...@@ -523,7 +521,7 @@ class JiugeForCauslLM:
return self.meta.dctx return self.meta.dctx
def create_kv_cache(self): def create_kv_cache(self):
return create_kv_cache( return self.jiuge_model.create_kv_cache(
self.meta.nlayer, self.meta.nlayer,
self.meta.dctx, self.meta.dctx,
self.meta.nkvh, self.meta.nkvh,
...@@ -536,12 +534,12 @@ class JiugeForCauslLM: ...@@ -536,12 +534,12 @@ class JiugeForCauslLM:
) )
def drop_kv_cache(self, kv_cache): def drop_kv_cache(self, kv_cache):
drop_kv_cache(kv_cache) self.jiuge_model.drop_kv_cache(kv_cache)
def batch_infer_one_round(self, tasks: List[InferTask]): def batch_infer_one_round(self, tasks: List[InferTask]):
output = (c_uint * len(tasks))() output = (c_uint * len(tasks))()
batch_inputs = JiugeBatchedTask(tasks) batch_inputs = JiugeBatchedTask(tasks)
infer_batch_jiuge( self.jiuge_model.infer_batch(
self.model_instance, self.model_instance,
*(batch_inputs.input_args()), *(batch_inputs.input_args()),
output, output,
...@@ -621,7 +619,7 @@ class JiugeForCauslLM: ...@@ -621,7 +619,7 @@ class JiugeForCauslLM:
logits = torch.zeros( logits = torch.zeros(
(batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits
) )
forward_batch_jiuge( self.jiuge_model.forward_batch(
self.model_instance, self.model_instance,
batch_inputs.tokens, batch_inputs.tokens,
batch_inputs.ntok, batch_inputs.ntok,
...@@ -651,7 +649,7 @@ class JiugeForCauslLM: ...@@ -651,7 +649,7 @@ class JiugeForCauslLM:
return math.exp(nll / total_len) return math.exp(nll / total_len)
def destroy_model_instance(self): def destroy_model_instance(self):
destroy_jiuge_model(self.model_instance) self.jiuge_model.destroy_model(self.model_instance)
print("Model destroyed") print("Model destroyed")
......
from typing import List, Sequence from typing import List, Sequence
import math
import os
from pathlib import Path
import safetensors
import sys
import time
import json
import torch
import transformers
from libinfinicore_infer import ( from libinfinicore_infer import (
JiugeAWQModel,
JiugeAWQMetaCStruct, JiugeAWQMetaCStruct,
KVCacheCStruct,
DataType, DataType,
DeviceType, DeviceType,
load_model_weight, KVCacheCStruct,
create_jiuge_awq_weights,
create_jiuge_awq_model,
destroy_jiuge_awq_model,
create_kv_cache,
drop_kv_cache,
infer_batch_jiuge_awq,
forward_batch_jiuge_awq,
) )
from infer_task import InferTask, KVCache from infer_task import InferTask, KVCache
from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import os
from pathlib import Path
import safetensors
import sys
import time
import json
import math
import torch
import transformers
torch.set_default_device("cpu") torch.set_default_device("cpu")
...@@ -160,8 +153,10 @@ class JiugeAWQForCausalLM: ...@@ -160,8 +153,10 @@ class JiugeAWQForCausalLM:
self.device = device self.device = device
self.meta = JiugeAWQMetaFromConfig(config, max_tokens=max_tokens) self.meta = JiugeAWQMetaFromConfig(config, max_tokens=max_tokens)
self.weights = create_jiuge_awq_weights( self.jiuge_awq_model = JiugeAWQModel()
self.meta,
self.weights = self.jiuge_awq_model.create_weights(
byref(self.meta),
self.device, self.device,
ndev, ndev,
self.dev_ids, self.dev_ids,
...@@ -178,12 +173,9 @@ class JiugeAWQForCausalLM: ...@@ -178,12 +173,9 @@ class JiugeAWQForCausalLM:
self.load_all_safetensors_from_dir(os.path.join(model_dir_path)) self.load_all_safetensors_from_dir(os.path.join(model_dir_path))
self.model_instance = create_jiuge_awq_model( self.model_instance = self.jiuge_awq_model.create_model(
self.meta, byref(self.meta),
self.weights, self.weights,
device,
ndev,
self.dev_ids,
) )
load_end_time = time.time() load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s") print(f"Time used: {load_end_time - load_start_time:.3f}s")
...@@ -203,13 +195,15 @@ class JiugeAWQForCausalLM: ...@@ -203,13 +195,15 @@ class JiugeAWQForCausalLM:
tensor = tensor * self.meta.scale_input tensor = tensor * self.meta.scale_input
elif "lm_head.weight" in key: elif "lm_head.weight" in key:
tensor = tensor * self.meta.scale_output tensor = tensor * self.meta.scale_output
load_model_weight(self.weights, key, tensor.data_ptr()) self.jiuge_awq_model.load_weight(
self.weights, key, tensor.data_ptr()
)
def max_context_len(self): def max_context_len(self):
return self.meta.dctx return self.meta.dctx
def create_kv_cache(self): def create_kv_cache(self):
return create_kv_cache( return self.jiuge_awq_model.create_kv_cache(
self.meta.nlayer, self.meta.nlayer,
self.meta.dctx, self.meta.dctx,
self.meta.nkvh, self.meta.nkvh,
...@@ -222,12 +216,12 @@ class JiugeAWQForCausalLM: ...@@ -222,12 +216,12 @@ class JiugeAWQForCausalLM:
) )
def drop_kv_cache(self, kv_cache): def drop_kv_cache(self, kv_cache):
drop_kv_cache(kv_cache) self.jiuge_awq_model.drop_kv_cache(kv_cache)
def batch_infer_one_round(self, tasks: List[InferTask]): def batch_infer_one_round(self, tasks: List[InferTask]):
output = (c_uint * len(tasks))() output = (c_uint * len(tasks))()
batch_inputs = JiugeAWQBatchedTask(tasks) batch_inputs = JiugeAWQBatchedTask(tasks)
infer_batch_jiuge_awq( self.jiuge_awq_model.infer_batch(
self.model_instance, self.model_instance,
*(batch_inputs.input_args()), *(batch_inputs.input_args()),
output, output,
...@@ -308,7 +302,7 @@ class JiugeAWQForCausalLM: ...@@ -308,7 +302,7 @@ class JiugeAWQForCausalLM:
logits = torch.zeros( logits = torch.zeros(
(batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits
) )
forward_batch_jiuge_awq( self.jiuge_awq_model.forward_batch(
self.model_instance, self.model_instance,
batch_inputs.tokens, batch_inputs.tokens,
batch_inputs.ntok, batch_inputs.ntok,
...@@ -338,14 +332,14 @@ class JiugeAWQForCausalLM: ...@@ -338,14 +332,14 @@ class JiugeAWQForCausalLM:
return math.exp(nll / total_len) return math.exp(nll / total_len)
def destroy_model_instance(self): def destroy_model_instance(self):
destroy_jiuge_awq_model(self.model_instance) self.jiuge_awq_model.destroy_model(self.model_instance)
print("Model destroyed") print("Model destroyed")
def test(): def test():
if len(sys.argv) < 3: if len(sys.argv) < 3:
print( print(
"Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]" "Usage: python jiuge_awq.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
) )
sys.exit(1) sys.exit(1)
model_path = sys.argv[2] model_path = sys.argv[2]
...@@ -366,7 +360,7 @@ def test(): ...@@ -366,7 +360,7 @@ def test():
device_type = DeviceType.DEVICE_TYPE_ILUVATAR device_type = DeviceType.DEVICE_TYPE_ILUVATAR
else: else:
print( print(
"Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]" "Usage: python main_jiuge_awq.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
) )
sys.exit(1) sys.exit(1)
......
import ctypes
from ctypes import c_char, c_char_p, c_size_t, c_uint, c_int, c_float, c_void_p, POINTER
import os
class DataType(ctypes.c_int):
    """Mirror of the C `infiniDtype_t` enum.

    Values must stay numerically identical to the C header — they cross the
    ctypes ABI boundary as plain ints.
    """

    INFINI_DTYPE_INVALID = 0
    INFINI_DTYPE_BYTE = 1
    INFINI_DTYPE_BOOL = 2
    INFINI_DTYPE_I8 = 3
    INFINI_DTYPE_I16 = 4
    INFINI_DTYPE_I32 = 5
    INFINI_DTYPE_I64 = 6
    INFINI_DTYPE_U8 = 7
    INFINI_DTYPE_U16 = 8
    INFINI_DTYPE_U32 = 9
    INFINI_DTYPE_U64 = 10
    INFINI_DTYPE_F8 = 11
    INFINI_DTYPE_F16 = 12
    INFINI_DTYPE_F32 = 13
    INFINI_DTYPE_F64 = 14
    INFINI_DTYPE_C16 = 15
    INFINI_DTYPE_C32 = 16
    INFINI_DTYPE_C64 = 17
    INFINI_DTYPE_C128 = 18
    INFINI_DTYPE_BF16 = 19
class DeviceType(ctypes.c_int):
    """Mirror of the C `infiniDevice_t` enum; values must match the C header."""

    DEVICE_TYPE_CPU = 0
    DEVICE_TYPE_NVIDIA = 1
    DEVICE_TYPE_CAMBRICON = 2
    DEVICE_TYPE_ASCEND = 3
    DEVICE_TYPE_METAX = 4
    DEVICE_TYPE_MOORE = 5
    DEVICE_TYPE_ILUVATAR = 6
class JiugeMetaCStruct(ctypes.Structure):
    """Layout of the C `JiugeMeta` struct (model hyper-parameters).

    Field order and types must match the C definition exactly.
    """

    _fields_ = [
        ("dt_logits", DataType),  # dtype of logits
        ("nlayer", c_size_t),  # number of transformer layers
        ("d", c_size_t),  # hidden size
        ("nh", c_size_t),  # attention heads
        ("nkvh", c_size_t),  # KV heads
        ("dh", c_size_t),  # per-head dim
        ("di", c_size_t),  # FFN intermediate size
        ("dctx", c_size_t),  # max context length
        ("dvoc", c_size_t),  # vocab size
        ("epsilon", c_float),  # norm epsilon
        ("theta", c_float),  # RoPE theta
        ("end_token", c_uint),  # EOS token id
    ]
# C `JiugeWeights` struct: host pointers to the model's weight tensors.
class JiugeWeightsCStruct(ctypes.Structure):
    """Layout of the C `JiugeWeights` struct.

    `c_void_p` fields are single tensors; `POINTER(c_void_p)` fields are
    per-layer arrays of tensor pointers.
    """

    _fields_ = [
        ("nlayer", c_size_t),
        ("dt_norm", DataType),
        ("dt_mat", DataType),
        ("transpose_linear_weights", c_int),
        ("input_embd", c_void_p),
        ("output_norm", c_void_p),
        ("output_embd", c_void_p),
        ("attn_norm", POINTER(c_void_p)),
        ("attn_qkv", POINTER(c_void_p)),
        ("attn_qkv_b", POINTER(c_void_p)),
        ("attn_o", POINTER(c_void_p)),
        ("ffn_norm", POINTER(c_void_p)),
        ("ffn_gate_up", POINTER(c_void_p)),
        ("ffn_down", POINTER(c_void_p)),
    ]
# NOTE(review): name is misspelled ("CSruct"); kept as-is because the
# registration code below references this exact identifier.
class JiugeModelCSruct(ctypes.Structure):
    # Opaque handle to a C-side Jiuge model; only used behind POINTER().
    pass
class DeepSeekV3MetaCStruct(ctypes.Structure):
    """Layout of the C `DeepSeekV3Meta` struct (MLA/MoE hyper-parameters).

    Field order and types must match the C definition exactly.
    """

    _fields_ = [
        # dtypes
        ("dt_logits", DataType),
        ("dt_norm", DataType),
        ("dt_quant_weight", DataType),
        ("dt_quant_scale", DataType),
        ("dt_quant_zero", DataType),
        ("dt_gate_weight", DataType),
        ("dt_gate_bias", DataType),
        # sizes
        ("n_sparse_layer", c_size_t),
        ("n_dense_layer", c_size_t),
        ("d", c_size_t),
        ("nh", c_size_t),
        ("nkvh", c_size_t),
        ("d_rope", c_size_t),
        ("d_nope", c_size_t),
        ("r_q", c_size_t),
        ("r_kv", c_size_t),
        ("d_qk", c_size_t),
        ("d_v", c_size_t),
        # routing / experts / vocab / ctx
        ("routed_scale", c_float),
        ("nexperts", c_size_t),
        ("kexperts", c_size_t),
        ("di", c_size_t),
        ("di_moe", c_size_t),
        ("dctx", c_size_t),
        ("dvoc", c_size_t),
        # misc
        ("epsilon", c_float),
        ("rope_theta", c_float),
        ("end_token", c_uint),
    ]
class DeepSeekV3WeightsCStruct(ctypes.Structure):
    # Opaque handle to the C-side DeepSeekV3 weight store.
    pass
# Callback types for the C weight-loader vtable below.
# void (*load_global_fn)(DeepSeekV3Weights*, void *cpu_ptr)
load_global_fn = ctypes.CFUNCTYPE(None, POINTER(DeepSeekV3WeightsCStruct), c_void_p)

# void (*load_layer_fn)(DeepSeekV3Weights*, void *cpu_ptr, size_t layer_id)
load_layer_fn = ctypes.CFUNCTYPE(
    None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_size_t
)

# void (*load_layer_linear_fn)(DeepSeekV3Weights*, void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer_id)
load_layer_linear_fn = ctypes.CFUNCTYPE(
    None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_void_p, c_void_p, c_size_t
)

# void (*load_layer_mlp_fn)(DeepSeekV3Weights*, ... , size_t layer_id)
# Nine void* args: gate/up/down projections, each as (weight, scale, zero).
load_layer_mlp_fn = ctypes.CFUNCTYPE(
    None,
    POINTER(DeepSeekV3WeightsCStruct),
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_size_t,
)

# void (*load_layer_expert_mlp_fn)(DeepSeekV3Weights*, ..., size_t layer_id, size_t expert_id)
# Same nine void* args as load_layer_mlp_fn plus a trailing expert index.
load_layer_expert_mlp_fn = ctypes.CFUNCTYPE(
    None,
    POINTER(DeepSeekV3WeightsCStruct),
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_size_t,
    c_size_t,
)
# -------------------------------------------------------------------
# Struct containing all weight loading functions
# -------------------------------------------------------------------
class DeepSeekV3WeightLoaderCStruct(ctypes.Structure):
    """Vtable of C callbacks used to push host tensors into the weight store.

    Field order must match the C struct; callers invoke these through
    `weight_loader.contents.<field>(...)`.
    """

    _fields_ = [
        # Global
        ("load_input_embd", load_global_fn),
        ("load_output_norm", load_global_fn),
        ("load_output_embd", load_global_fn),
        # Attention
        ("load_attn_norm", load_layer_fn),
        ("load_attn_q_a_proj", load_layer_linear_fn),
        ("load_attn_q_a_layernorm", load_layer_fn),
        ("load_attn_q_b_proj", load_layer_linear_fn),
        ("load_attn_kv_a_proj_with_mqa", load_layer_linear_fn),
        ("load_attn_kv_a_layernorm", load_layer_fn),
        ("load_attn_kv_b_proj", load_layer_linear_fn),
        ("load_attn_o_proj", load_layer_linear_fn),
        # MLP
        ("load_mlp_norm", load_layer_fn),
        # MLP dense part
        ("load_mlp_dense", load_layer_mlp_fn),
        # MLP sparse gating
        ("load_mlp_gate_weight", load_layer_fn),
        ("load_mlp_gate_bias", load_layer_fn),
        # Shared experts
        ("load_mlp_shared_experts", load_layer_mlp_fn),
        # Per-expert functions
        ("load_mlp_experts", load_layer_expert_mlp_fn),
    ]
# Opaque C handles — never dereferenced from Python, only passed around
# behind POINTER().
class DeepSeekV3ModelCStruct(ctypes.Structure):
    pass


class KVCacheCStruct(ctypes.Structure):
    pass


class DeepSeekV3CacheCStruct(ctypes.Structure):
    pass
class JiugeAWQMetaCStruct(ctypes.Structure):
    """Layout of the C `JiugeAWQMeta` struct (AWQ-quantized Jiuge).

    Field order and types must match the C definition exactly.
    """

    _fields_ = [
        ("dt_logits", DataType),
        ("dt_linear_w", DataType),
        ("dt_norm_w", DataType),
        ("nlayer", c_size_t),
        ("d", c_size_t),
        ("nh", c_size_t),
        ("nkvh", c_size_t),
        ("dh", c_size_t),
        ("di", c_size_t),
        ("dctx", c_size_t),
        ("dvoc", c_size_t),
        ("epsilon", c_float),
        ("theta", c_float),
        ("end_token", c_uint),
        ("nbit", c_size_t),  # quantization bit width
        ("quant_group_size", c_size_t),
        ("has_qkv_bias", c_char),
    ]
# Opaque C handles for the AWQ path.
class ModelWeightsCStruct(ctypes.Structure):
    pass


class JiugeAWQModelCStruct(ctypes.Structure):
    pass  # opaque struct
def __open_library__():
    """Load ``libinfinicore_infer.so`` and declare all C entry-point signatures.

    Returns:
        ctypes.CDLL: library handle with argtypes/restype configured for
        every exported function used by this module.

    Raises:
        RuntimeError: if the ``INFINI_ROOT`` environment variable is not set.
        OSError: if the shared library cannot be loaded.
    """
    root = os.environ.get("INFINI_ROOT")
    if root is None:
        # Fail fast with a clear message instead of the confusing TypeError
        # os.path.join(None, ...) would raise.
        raise RuntimeError(
            "INFINI_ROOT environment variable is not set; it must point to "
            "the InfiniCore installation prefix."
        )
    lib_path = os.path.join(root, "lib", "libinfinicore_infer.so")
    lib = ctypes.CDLL(lib_path)

    # --- Shared KV cache ---
    lib.createKVCache.argtypes = [
        c_size_t,  # nlayers
        c_size_t,  # max_len
        c_size_t,  # nkvh_
        c_size_t,  # dk
        c_size_t,  # dv
        DataType,  # dtype
        DeviceType,  # device
        POINTER(c_int),  # dev_ids
        c_size_t,  # ndev
    ]
    lib.createKVCache.restype = POINTER(KVCacheCStruct)
    lib.dropKVCache.argtypes = [POINTER(KVCacheCStruct)]
    lib.dropKVCache.restype = None  # void in C

    # --- Jiuge ---
    lib.createJiugeModel.restype = POINTER(JiugeModelCSruct)
    lib.createJiugeModel.argtypes = [
        POINTER(JiugeMetaCStruct),  # JiugeMeta const *
        POINTER(JiugeWeightsCStruct),  # JiugeWeights const *
        DeviceType,  # DeviceType
        c_int,  # int ndev
        POINTER(c_int),  # int const *dev_ids
    ]
    lib.destroyJiugeModel.argtypes = [POINTER(JiugeModelCSruct)]
    lib.destroyJiugeModel.restype = None  # void in C
    lib.inferBatchJiuge.restype = None
    lib.inferBatchJiuge.argtypes = [
        POINTER(JiugeModelCSruct),  # struct JiugeModel const *
        POINTER(c_uint),  # unsigned int const *tokens
        c_uint,  # unsigned int ntok
        POINTER(c_uint),  # unsigned int const *req_lens
        c_uint,  # unsigned int nreq
        POINTER(c_uint),  # unsigned int const *req_pos
        POINTER(POINTER(KVCacheCStruct)),  # struct KVCache **kv_caches
        POINTER(c_float),  # float temperature
        POINTER(c_uint),  # unsigned int topk
        POINTER(c_float),  # float topp
        POINTER(c_uint),  # unsigned int *output
    ]
    lib.forwardBatchJiuge.restype = None
    lib.forwardBatchJiuge.argtypes = [
        POINTER(JiugeModelCSruct),  # struct JiugeModel const *
        POINTER(c_uint),  # unsigned int const *tokens
        c_uint,  # unsigned int ntok
        POINTER(c_uint),  # unsigned int const *req_lens
        c_uint,  # unsigned int nreq
        POINTER(c_uint),  # unsigned int const *req_pos
        POINTER(POINTER(KVCacheCStruct)),  # struct KVCache **kv_caches
        c_void_p,  # void *logits
    ]

    # --- DeepSeekV3 ---
    # createDeepSeekV3WeightLoader
    lib.createDeepSeekV3WeightLoader.argtypes = []
    lib.createDeepSeekV3WeightLoader.restype = POINTER(DeepSeekV3WeightLoaderCStruct)
    lib.createDeepSeekV3Weights.argtypes = [
        POINTER(DeepSeekV3MetaCStruct),
        DeviceType,
        c_int,
        POINTER(c_int),
    ]
    lib.createDeepSeekV3Weights.restype = POINTER(DeepSeekV3WeightsCStruct)
    lib.createDeepSeekV3Model.argtypes = [
        POINTER(DeepSeekV3MetaCStruct),
        POINTER(DeepSeekV3WeightsCStruct),
    ]
    lib.createDeepSeekV3Model.restype = POINTER(DeepSeekV3ModelCStruct)
    # destroyDeepSeekV3Model
    lib.destroyDeepSeekV3Model.argtypes = [POINTER(DeepSeekV3ModelCStruct)]
    lib.destroyDeepSeekV3Model.restype = None
    # createDeepSeekV3Cache
    lib.createDeepSeekV3Cache.argtypes = [POINTER(DeepSeekV3ModelCStruct)]
    lib.createDeepSeekV3Cache.restype = POINTER(DeepSeekV3CacheCStruct)
    # dropDeepSeekV3Cache
    lib.dropDeepSeekV3Cache.argtypes = [
        POINTER(DeepSeekV3ModelCStruct),
        POINTER(DeepSeekV3CacheCStruct),
    ]
    lib.dropDeepSeekV3Cache.restype = None
    # inferBatchDeepSeekV3
    lib.inferBatchDeepSeekV3.argtypes = [
        POINTER(DeepSeekV3ModelCStruct),
        POINTER(c_uint),
        c_uint,
        POINTER(c_uint),
        c_uint,
        POINTER(c_uint),
        POINTER(POINTER(DeepSeekV3CacheCStruct)),
        POINTER(c_float),
        POINTER(c_uint),
        POINTER(c_float),
        POINTER(c_uint),
    ]
    lib.inferBatchDeepSeekV3.restype = None
    # forwardBatchDeepSeekV3
    lib.forwardBatchDeepSeekV3.argtypes = [
        POINTER(DeepSeekV3ModelCStruct),
        POINTER(c_uint),
        c_uint,
        POINTER(c_uint),
        c_uint,
        POINTER(c_uint),
        POINTER(POINTER(DeepSeekV3CacheCStruct)),
        c_void_p,
    ]
    lib.forwardBatchDeepSeekV3.restype = None

    # --- Jiuge AWQ ---
    lib.createJiugeAWQWeights.restype = POINTER(ModelWeightsCStruct)
    lib.createJiugeAWQWeights.argtypes = [
        POINTER(JiugeAWQMetaCStruct),  # const JiugeAWQMeta*
        DeviceType,  # infiniDevice_t
        c_int,  # int ndev
        POINTER(c_int),  # const int* dev_ids
    ]
    # createJiugeAWQModel
    lib.createJiugeAWQModel.restype = POINTER(JiugeAWQModelCStruct)
    lib.createJiugeAWQModel.argtypes = [
        POINTER(JiugeAWQMetaCStruct),  # const JiugeAWQMeta*
        POINTER(ModelWeightsCStruct),  # const ModelWeights*
    ]
    # destroyJiugeAWQModel
    lib.destroyJiugeAWQModel.argtypes = [POINTER(JiugeAWQModelCStruct)]
    lib.destroyJiugeAWQModel.restype = None
    # inferBatchJiugeAWQ
    lib.inferBatchJiugeAWQ.argtypes = [
        POINTER(JiugeAWQModelCStruct),  # JiugeAWQModel*
        POINTER(c_uint),  # const uint32_t* tokens
        c_uint,  # uint32_t ntok
        POINTER(c_uint),  # const uint32_t* req_lens
        c_uint,  # uint32_t nreq
        POINTER(c_uint),  # const uint32_t* req_pos
        POINTER(POINTER(KVCacheCStruct)),  # struct KVCache** kv_caches
        POINTER(c_float),  # const float* temperature
        POINTER(c_uint),  # const uint32_t* topk
        POINTER(c_float),  # const float* topp
        POINTER(c_uint),  # uint32_t* output
    ]
    lib.inferBatchJiugeAWQ.restype = None
    # forwardBatchJiugeAWQ
    lib.forwardBatchJiugeAWQ.argtypes = [
        POINTER(JiugeAWQModelCStruct),  # JiugeAWQModel*
        POINTER(c_uint),  # const uint32_t* tokens
        c_uint,  # uint32_t ntok
        POINTER(c_uint),  # const uint32_t* req_lens
        c_uint,  # uint32_t nreq
        POINTER(c_uint),  # const uint32_t* req_pos
        POINTER(POINTER(KVCacheCStruct)),  # struct KVCache** kv_caches
        c_void_p,  # void* logits
    ]
    lib.forwardBatchJiugeAWQ.restype = None
    lib.loadModelWeight.argtypes = [
        POINTER(ModelWeightsCStruct),  # struct ModelWeights*
        c_char_p,  # const char* name
        c_void_p,  # void* data
    ]
    lib.loadModelWeight.restype = None
    return lib
# Module-level singleton library handle; all wrappers below call into it.
LIB = __open_library__()
def load_model_weight(weights, name, data):
    """Push one named host tensor into the C-side AWQ weight store.

    Args:
        weights: POINTER(ModelWeightsCStruct) from createJiugeAWQWeights.
        name: tensor name as a Python str (encoded to UTF-8 for C).
        data: host pointer (e.g. tensor.data_ptr()) to the tensor bytes.
    """
    encoded_name = name.encode("utf-8")
    LIB.loadModelWeight(weights, encoded_name, data)
# Snake_case aliases for the raw C entry points.
# Jiuge
create_jiuge_model = LIB.createJiugeModel
destroy_jiuge_model = LIB.destroyJiugeModel
create_kv_cache = LIB.createKVCache
drop_kv_cache = LIB.dropKVCache
infer_batch_jiuge = LIB.inferBatchJiuge
forward_batch_jiuge = LIB.forwardBatchJiuge
# Jiuge AWQ
create_jiuge_awq_weights = LIB.createJiugeAWQWeights
create_jiuge_awq_model = LIB.createJiugeAWQModel
destroy_jiuge_awq_model = LIB.destroyJiugeAWQModel
infer_batch_jiuge_awq = LIB.inferBatchJiugeAWQ
forward_batch_jiuge_awq = LIB.forwardBatchJiugeAWQ
# DeepSeekV3
create_deepseek_v3_model = LIB.createDeepSeekV3Model
destroy_deepseek_v3_model = LIB.destroyDeepSeekV3Model
create_deepseek_v3_weight_loader = LIB.createDeepSeekV3WeightLoader
create_deepseek_v3_weights = LIB.createDeepSeekV3Weights
create_deepseek_v3_cache = LIB.createDeepSeekV3Cache
drop_deepseek_v3_cache = LIB.dropDeepSeekV3Cache
infer_batch_deepseek_v3 = LIB.inferBatchDeepSeekV3
"""Public API of the libinfinicore_infer package."""

from .base import DataType, DeviceType, KVCacheCStruct
from .jiuge import JiugeModel, JiugeMetaCStruct, JiugeWeightsCStruct
from .jiuge_awq import JiugeAWQModel, JiugeAWQMetaCStruct, ModelWeightsCStruct
from .deepseek_v3 import (
    DeepSeekV3Model,
    DeepSeekV3MetaCStruct,
    DeepSeekV3WeightsCStruct,
    DeepSeekV3WeightLoaderCStruct,
    DeepSeekV3CacheCStruct,
)

__all__ = [
    "DataType",
    "DeviceType",
    "KVCacheCStruct",
    "JiugeModel",
    "JiugeMetaCStruct",
    "JiugeWeightsCStruct",
    "JiugeAWQModel",
    "JiugeAWQMetaCStruct",
    "ModelWeightsCStruct",
    "DeepSeekV3Model",
    "DeepSeekV3MetaCStruct",
    "DeepSeekV3WeightsCStruct",
    "DeepSeekV3WeightLoaderCStruct",
    # Fix: the previous entry "ModelRegister" does not exist anywhere in the
    # package, so `from libinfinicore_infer import *` raised AttributeError.
    # DeepSeekV3CacheCStruct is imported above and belongs in the public API.
    "DeepSeekV3CacheCStruct",
]
import ctypes
from ctypes import c_char, c_char_p, c_size_t, c_uint, c_int, c_float, c_void_p, POINTER
import os
class DataType(ctypes.c_int):
    """Mirror of the C `infiniDtype_t` enum.

    Values must stay numerically identical to the C header — they cross the
    ctypes ABI boundary as plain ints.
    """

    INFINI_DTYPE_INVALID = 0
    INFINI_DTYPE_BYTE = 1
    INFINI_DTYPE_BOOL = 2
    INFINI_DTYPE_I8 = 3
    INFINI_DTYPE_I16 = 4
    INFINI_DTYPE_I32 = 5
    INFINI_DTYPE_I64 = 6
    INFINI_DTYPE_U8 = 7
    INFINI_DTYPE_U16 = 8
    INFINI_DTYPE_U32 = 9
    INFINI_DTYPE_U64 = 10
    INFINI_DTYPE_F8 = 11
    INFINI_DTYPE_F16 = 12
    INFINI_DTYPE_F32 = 13
    INFINI_DTYPE_F64 = 14
    INFINI_DTYPE_C16 = 15
    INFINI_DTYPE_C32 = 16
    INFINI_DTYPE_C64 = 17
    INFINI_DTYPE_C128 = 18
    INFINI_DTYPE_BF16 = 19
class DeviceType(ctypes.c_int):
    """Mirror of the C `infiniDevice_t` enum; values must match the C header."""

    DEVICE_TYPE_CPU = 0
    DEVICE_TYPE_NVIDIA = 1
    DEVICE_TYPE_CAMBRICON = 2
    DEVICE_TYPE_ASCEND = 3
    DEVICE_TYPE_METAX = 4
    DEVICE_TYPE_MOORE = 5
    DEVICE_TYPE_ILUVATAR = 6
class KVCacheCStruct(ctypes.Structure):
    # Opaque handle to a C-side KV cache; only used behind POINTER().
    pass
# Model registration system
_model_registry = []
def register_model(model_class):
"""Decorator to register a model class"""
_model_registry.append(model_class)
return model_class
def register_lib_functions(lib):
"""Register all model functions with the library"""
for model_class in _model_registry:
model_class.register_lib(lib)
class BaseModel:
    """Common base for model wrappers.

    On construction, loads the shared C library and asks every registered
    model class to declare its ctypes signatures on the handle. Note that
    each instance re-opens the library and re-declares signatures; the
    declarations are plain attribute assignments and therefore idempotent.
    """

    def __init__(self):
        # self.lib: ctypes.CDLL handle used by all subclass methods.
        self.lib = self._load_library()
        register_lib_functions(self.lib)

    def _load_library(self):
        """Load ``libinfinicore_infer.so`` from ``$INFINI_ROOT/lib``.

        Raises:
            RuntimeError: if the INFINI_ROOT environment variable is not set
                (instead of the confusing TypeError os.path.join(None, ...)
                would otherwise raise).
        """
        root = os.environ.get("INFINI_ROOT")
        if root is None:
            raise RuntimeError(
                "INFINI_ROOT environment variable is not set; it must point "
                "to the InfiniCore installation prefix."
            )
        return ctypes.CDLL(os.path.join(root, "lib", "libinfinicore_infer.so"))
from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, register_model
from ctypes import (
c_size_t,
c_uint,
c_int,
c_float,
c_void_p,
POINTER,
Structure,
CFUNCTYPE,
)
class DeepSeekV3MetaCStruct(Structure):
    """Layout of the C `DeepSeekV3Meta` struct (MLA/MoE hyper-parameters).

    Field order and types must match the C definition exactly.
    """

    _fields_ = [
        # dtypes
        ("dt_logits", DataType),
        ("dt_norm", DataType),
        ("dt_quant_weight", DataType),
        ("dt_quant_scale", DataType),
        ("dt_quant_zero", DataType),
        ("dt_gate_weight", DataType),
        ("dt_gate_bias", DataType),
        # sizes
        ("n_sparse_layer", c_size_t),
        ("n_dense_layer", c_size_t),
        ("d", c_size_t),
        ("nh", c_size_t),
        ("nkvh", c_size_t),
        ("d_rope", c_size_t),
        ("d_nope", c_size_t),
        ("r_q", c_size_t),
        ("r_kv", c_size_t),
        ("d_qk", c_size_t),
        ("d_v", c_size_t),
        # routing / experts / vocab / ctx
        ("routed_scale", c_float),
        ("nexperts", c_size_t),
        ("kexperts", c_size_t),
        ("di", c_size_t),
        ("di_moe", c_size_t),
        ("dctx", c_size_t),
        ("dvoc", c_size_t),
        # misc
        ("epsilon", c_float),
        ("rope_theta", c_float),
        ("end_token", c_uint),
    ]
# Opaque C handles — never dereferenced from Python, only used behind POINTER().
class DeepSeekV3WeightsCStruct(Structure):
    pass


class DeepSeekV3ModelCStruct(Structure):
    pass


class DeepSeekV3CacheCStruct(Structure):
    pass
# Callback types for the C weight-loader vtable below.
# void (*load_global_fn)(DeepSeekV3Weights*, void *cpu_ptr)
load_global_fn = CFUNCTYPE(None, POINTER(DeepSeekV3WeightsCStruct), c_void_p)
# void (*load_layer_fn)(DeepSeekV3Weights*, void *cpu_ptr, size_t layer_id)
load_layer_fn = CFUNCTYPE(None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_size_t)
# void (*load_layer_linear_fn)(DeepSeekV3Weights*, void *weight_ptr,
#                              void *scale_ptr, void *zero_ptr, size_t layer_id)
load_layer_linear_fn = CFUNCTYPE(
    None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_void_p, c_void_p, c_size_t
)
# void (*load_layer_mlp_fn)(DeepSeekV3Weights*, <nine void* tensor args>, size_t layer_id)
load_layer_mlp_fn = CFUNCTYPE(
    None,
    POINTER(DeepSeekV3WeightsCStruct),
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_size_t,
)
# void (*load_layer_expert_mlp_fn)(DeepSeekV3Weights*, <nine void* tensor args>,
#                                  size_t layer_id, size_t expert_id)
load_layer_expert_mlp_fn = CFUNCTYPE(
    None,
    POINTER(DeepSeekV3WeightsCStruct),
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_size_t,
    c_size_t,
)
class DeepSeekV3WeightLoaderCStruct(Structure):
    """Vtable of C callbacks used to push host tensors into the weight store.

    Field order must match the C struct; callers invoke these through
    `weight_loader.contents.<field>(...)`.
    """

    _fields_ = [
        # Global tensors
        ("load_input_embd", load_global_fn),
        ("load_output_norm", load_global_fn),
        ("load_output_embd", load_global_fn),
        # Attention (MLA projections)
        ("load_attn_norm", load_layer_fn),
        ("load_attn_q_a_proj", load_layer_linear_fn),
        ("load_attn_q_a_layernorm", load_layer_fn),
        ("load_attn_q_b_proj", load_layer_linear_fn),
        ("load_attn_kv_a_proj_with_mqa", load_layer_linear_fn),
        ("load_attn_kv_a_layernorm", load_layer_fn),
        ("load_attn_kv_b_proj", load_layer_linear_fn),
        ("load_attn_o_proj", load_layer_linear_fn),
        # MLP (dense layers, gating, shared and routed experts)
        ("load_mlp_norm", load_layer_fn),
        ("load_mlp_dense", load_layer_mlp_fn),
        ("load_mlp_gate_weight", load_layer_fn),
        ("load_mlp_gate_bias", load_layer_fn),
        ("load_mlp_shared_experts", load_layer_mlp_fn),
        ("load_mlp_experts", load_layer_expert_mlp_fn),
    ]
@register_model
class DeepSeekV3Model(BaseModel):
    """Python wrapper around the DeepSeekV3 C API of libinfinicore_infer."""

    @classmethod
    def register_lib(cls, lib):
        """Register DeepSeekV3 model functions with the library"""
        lib.createDeepSeekV3WeightLoader.argtypes = []
        lib.createDeepSeekV3WeightLoader.restype = POINTER(
            DeepSeekV3WeightLoaderCStruct
        )
        lib.createDeepSeekV3Weights.argtypes = [
            POINTER(DeepSeekV3MetaCStruct),
            DeviceType,
            c_int,
            POINTER(c_int),
        ]
        lib.createDeepSeekV3Weights.restype = POINTER(DeepSeekV3WeightsCStruct)
        lib.createDeepSeekV3Model.argtypes = [
            POINTER(DeepSeekV3MetaCStruct),
            POINTER(DeepSeekV3WeightsCStruct),
        ]
        lib.createDeepSeekV3Model.restype = POINTER(DeepSeekV3ModelCStruct)
        lib.destroyDeepSeekV3Model.argtypes = [POINTER(DeepSeekV3ModelCStruct)]
        # void functions: declare restype = None so ctypes does not interpret
        # a garbage return value as int (matches the monolithic bindings).
        lib.destroyDeepSeekV3Model.restype = None
        lib.createDeepSeekV3Cache.argtypes = [POINTER(DeepSeekV3ModelCStruct)]
        lib.createDeepSeekV3Cache.restype = POINTER(DeepSeekV3CacheCStruct)
        lib.dropDeepSeekV3Cache.argtypes = [
            POINTER(DeepSeekV3ModelCStruct),
            POINTER(DeepSeekV3CacheCStruct),
        ]
        lib.dropDeepSeekV3Cache.restype = None
        lib.inferBatchDeepSeekV3.argtypes = [
            POINTER(DeepSeekV3ModelCStruct),
            POINTER(c_uint),
            c_uint,
            POINTER(c_uint),
            c_uint,
            POINTER(c_uint),
            POINTER(POINTER(DeepSeekV3CacheCStruct)),
            POINTER(c_float),
            POINTER(c_uint),
            POINTER(c_float),
            POINTER(c_uint),
        ]
        lib.inferBatchDeepSeekV3.restype = None

    def create_weight_loader(self):
        """Return a POINTER(DeepSeekV3WeightLoaderCStruct) of C loader callbacks."""
        return self.lib.createDeepSeekV3WeightLoader()

    def create_weights(self, meta, device_type, ndev, dev_ids):
        """Allocate the C-side weight store for *meta* on the given devices."""
        return self.lib.createDeepSeekV3Weights(meta, device_type, ndev, dev_ids)

    def create_model(self, meta, weights):
        """Create a C model handle from *meta* and a populated weight store."""
        return self.lib.createDeepSeekV3Model(meta, weights)

    def destroy_model(self, model):
        """Release the C model handle."""
        self.lib.destroyDeepSeekV3Model(model)

    def create_cache(self, model):
        """Allocate a KV/MLA cache bound to *model*."""
        return self.lib.createDeepSeekV3Cache(model)

    def drop_cache(self, model, cache):
        """Release a cache previously created for *model*."""
        self.lib.dropDeepSeekV3Cache(model, cache)

    def infer_batch(
        self,
        model,
        tokens,
        ntok,
        req_lens,
        nreq,
        req_pos,
        caches,
        temperature,
        topk,
        topp,
        output,
    ):
        """Run one batched decoding step; sampled token ids are written to *output*."""
        self.lib.inferBatchDeepSeekV3(
            model,
            tokens,
            ntok,
            req_lens,
            nreq,
            req_pos,
            caches,
            temperature,
            topk,
            topp,
            output,
        )
from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, register_model
from ctypes import c_size_t, c_uint, c_int, c_float, c_void_p, POINTER, Structure, byref
class JiugeMetaCStruct(Structure):
    """Layout of the C `JiugeMeta` struct (model hyper-parameters).

    Field order and types must match the C definition exactly.
    """

    _fields_ = [
        ("dt_logits", DataType),  # dtype of logits
        ("nlayer", c_size_t),  # number of transformer layers
        ("d", c_size_t),  # hidden size
        ("nh", c_size_t),  # attention heads
        ("nkvh", c_size_t),  # KV heads
        ("dh", c_size_t),  # per-head dim
        ("di", c_size_t),  # FFN intermediate size
        ("dctx", c_size_t),  # max context length
        ("dvoc", c_size_t),  # vocab size
        ("epsilon", c_float),  # norm epsilon
        ("theta", c_float),  # RoPE theta
        ("end_token", c_uint),  # EOS token id
    ]
class JiugeWeightsCStruct(Structure):
    """ctypes mirror of the C-side Jiuge weights struct.

    Scalar fields are single tensors (``c_void_p``); the ``POINTER(c_void_p)``
    fields are per-layer arrays of tensor pointers.  Field order and types
    must match the C declaration exactly.
    """

    _fields_ = [
        ("nlayer", c_size_t),                 # number of layer entries in the arrays below
        ("dt_norm", DataType),                # data type of norm weights
        ("dt_mat", DataType),                 # data type of matrix weights
        ("transpose_linear_weights", c_int),  # nonzero if linear weights are stored transposed
        ("input_embd", c_void_p),             # input embedding tensor
        ("output_norm", c_void_p),            # final norm tensor
        ("output_embd", c_void_p),            # output embedding tensor
        ("attn_norm", POINTER(c_void_p)),     # per-layer attention norm
        ("attn_qkv", POINTER(c_void_p)),      # per-layer fused QKV projection
        ("attn_qkv_b", POINTER(c_void_p)),    # per-layer QKV bias
        ("attn_o", POINTER(c_void_p)),        # per-layer attention output projection
        ("ffn_norm", POINTER(c_void_p)),      # per-layer FFN norm
        ("ffn_gate_up", POINTER(c_void_p)),   # per-layer fused gate/up projection
        ("ffn_down", POINTER(c_void_p)),      # per-layer down projection
    ]
class JiugeModelCStruct(Structure):
    """Opaque handle to a native Jiuge model instance (fields defined only in C)."""

    pass
@register_model
class JiugeModel(BaseModel):
    """ctypes binding for the native Jiuge model library.

    Wraps creation/destruction of the model and its KV caches, plus the
    batched inference and forward entry points.
    """

    @classmethod
    def register_lib(cls, lib):
        """Declare argument/return types of the Jiuge C API on *lib*.

        Must run before any wrapper method is called; without explicit
        ``argtypes``/``restype`` ctypes defaults everything to int and
        corrupts the calls.  The lists mirror the C signatures exactly.
        """
        lib.createJiugeModel.restype = POINTER(JiugeModelCStruct)
        lib.createJiugeModel.argtypes = [
            POINTER(JiugeMetaCStruct),
            POINTER(JiugeWeightsCStruct),
            DeviceType,
            c_int,  # ndev
            POINTER(c_int),  # dev_ids
        ]
        lib.destroyJiugeModel.argtypes = [POINTER(JiugeModelCStruct)]
        # KV-cache helpers; argument order matches create_kv_cache() below.
        lib.createKVCache.argtypes = [
            c_size_t,  # nlayer
            c_size_t,  # max_len
            c_size_t,  # nkvh
            c_size_t,  # dk
            c_size_t,  # dv
            DataType,  # dtype
            DeviceType,  # device
            POINTER(c_int),  # dev_ids
            c_size_t,  # ndev
        ]
        lib.createKVCache.restype = POINTER(KVCacheCStruct)
        lib.dropKVCache.argtypes = [POINTER(KVCacheCStruct)]
        # Batched inference; argument order matches infer_batch() below.
        lib.inferBatchJiuge.argtypes = [
            POINTER(JiugeModelCStruct),  # model
            POINTER(c_uint),  # tokens
            c_uint,  # ntok
            POINTER(c_uint),  # req_lens
            c_uint,  # nreq
            POINTER(c_uint),  # req_pos
            POINTER(POINTER(KVCacheCStruct)),  # kv_caches
            POINTER(c_float),  # temperature
            POINTER(c_uint),  # topk
            POINTER(c_float),  # topp
            POINTER(c_uint),  # output
        ]
        # Forward pass; argument order matches forward_batch() below.
        lib.forwardBatchJiuge.argtypes = [
            POINTER(JiugeModelCStruct),  # model
            POINTER(c_uint),  # tokens
            c_uint,  # ntok
            POINTER(c_uint),  # req_lens
            c_uint,  # nreq
            POINTER(c_uint),  # req_pos
            POINTER(POINTER(KVCacheCStruct)),  # kv_caches
            c_void_p,  # logits (opaque output buffer)
        ]

    def create_model(self, meta, weights, device_type, ndev, dev_ids):
        """Create a native model instance and return its handle."""
        return self.lib.createJiugeModel(meta, weights, device_type, ndev, dev_ids)

    def destroy_model(self, model):
        """Release a native model created by create_model()."""
        self.lib.destroyJiugeModel(model)

    def create_kv_cache(
        self, nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev
    ):
        """Allocate a native KV cache and return its handle."""
        return self.lib.createKVCache(
            nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev
        )

    def drop_kv_cache(self, kv_cache):
        """Free a KV cache created by create_kv_cache()."""
        self.lib.dropKVCache(kv_cache)

    def infer_batch(
        self,
        model,
        tokens,
        ntok,
        req_lens,
        nreq,
        req_pos,
        kv_caches,
        temperature,
        topk,
        topp,
        output,
    ):
        """Run one batched inference step.

        NOTE(review): *output* is a ``c_uint`` pointer per the argtypes
        above, presumably receiving sampled token ids -- confirm against
        the C API.
        """
        self.lib.inferBatchJiuge(
            model,
            tokens,
            ntok,
            req_lens,
            nreq,
            req_pos,
            kv_caches,
            temperature,
            topk,
            topp,
            output,
        )

    def forward_batch(
        self, model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits
    ):
        """Run one batched forward pass.

        NOTE(review): *logits* is an opaque ``c_void_p`` buffer, presumably
        receiving the output logits -- confirm against the C API.
        """
        self.lib.forwardBatchJiuge(
            model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits
        )
from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, register_model
from ctypes import (
c_size_t,
c_uint,
c_int,
c_float,
c_void_p,
POINTER,
Structure,
c_char,
c_char_p,
)
class JiugeAWQMetaCStruct(Structure):
    """ctypes mirror of the C-side metadata struct for the AWQ-quantized Jiuge model.

    Field order and types must match the C declaration exactly; do not
    reorder or retype fields.

    NOTE(review): field meanings below are inferred from the names;
    confirm against the C header that defines this struct.
    """

    _fields_ = [
        ("dt_logits", DataType),          # data type of the logits
        ("dt_linear_w", DataType),        # data type of linear weights
        ("dt_norm_w", DataType),          # data type of norm weights
        ("nlayer", c_size_t),             # number of layers
        ("d", c_size_t),                  # hidden size
        ("nh", c_size_t),                 # attention head count
        ("nkvh", c_size_t),               # KV head count
        ("dh", c_size_t),                 # per-head dimension
        ("di", c_size_t),                 # FFN intermediate size
        ("dctx", c_size_t),               # max context length
        ("dvoc", c_size_t),               # vocabulary size
        ("epsilon", c_float),             # normalization epsilon
        ("theta", c_float),               # rotary-embedding base
        ("end_token", c_uint),            # end-of-sequence token id
        ("nbit", c_size_t),               # quantization bit width
        ("quant_group_size", c_size_t),   # quantization group size
        ("has_qkv_bias", c_char),         # flag: QKV projection has a bias
    ]
class ModelWeightsCStruct(Structure):
    """Opaque handle to natively allocated model weights (fields defined only in C)."""

    pass
class JiugeAWQModelCStruct(Structure):
    """Opaque handle to a native AWQ-quantized Jiuge model instance (fields defined only in C)."""

    pass
@register_model
class JiugeAWQModel(BaseModel):
    """ctypes binding for the native AWQ-quantized Jiuge model library.

    Unlike JiugeModel, weights are allocated natively (create_weights) and
    populated tensor-by-tensor through load_weight().
    """

    @classmethod
    def register_lib(cls, lib):
        """Register JiugeAWQ model functions with the library"""
        # Native weight allocation; argument order matches create_weights().
        lib.createJiugeAWQWeights.restype = POINTER(ModelWeightsCStruct)
        lib.createJiugeAWQWeights.argtypes = [
            POINTER(JiugeAWQMetaCStruct),
            DeviceType,
            c_int,  # ndev
            POINTER(c_int),  # dev_ids
        ]
        lib.createJiugeAWQModel.restype = POINTER(JiugeAWQModelCStruct)
        lib.createJiugeAWQModel.argtypes = [
            POINTER(JiugeAWQMetaCStruct),
            POINTER(ModelWeightsCStruct),
        ]
        lib.destroyJiugeAWQModel.argtypes = [POINTER(JiugeAWQModelCStruct)]
        # KV-cache helpers; argument order matches create_kv_cache() below.
        lib.createKVCache.argtypes = [
            c_size_t,  # nlayer
            c_size_t,  # max_len
            c_size_t,  # nkvh
            c_size_t,  # dk
            c_size_t,  # dv
            DataType,  # dtype
            DeviceType,  # device
            POINTER(c_int),  # dev_ids
            c_size_t,  # ndev
        ]
        lib.createKVCache.restype = POINTER(KVCacheCStruct)
        lib.dropKVCache.argtypes = [POINTER(KVCacheCStruct)]
        # Batched inference; argument order matches infer_batch() below.
        lib.inferBatchJiugeAWQ.argtypes = [
            POINTER(JiugeAWQModelCStruct),  # model
            POINTER(c_uint),  # tokens
            c_uint,  # ntok
            POINTER(c_uint),  # req_lens
            c_uint,  # nreq
            POINTER(c_uint),  # req_pos
            POINTER(POINTER(KVCacheCStruct)),  # kv_caches
            POINTER(c_float),  # temperature
            POINTER(c_uint),  # topk
            POINTER(c_float),  # topp
            POINTER(c_uint),  # output
        ]
        # Forward pass; argument order matches forward_batch() below.
        lib.forwardBatchJiugeAWQ.argtypes = [
            POINTER(JiugeAWQModelCStruct),  # model
            POINTER(c_uint),  # tokens
            c_uint,  # ntok
            POINTER(c_uint),  # req_lens
            c_uint,  # nreq
            POINTER(c_uint),  # req_pos
            POINTER(POINTER(KVCacheCStruct)),  # kv_caches
            c_void_p,  # logits (opaque output buffer)
        ]
        # Per-tensor weight upload; argument order matches load_weight() below.
        lib.loadModelWeight.argtypes = [
            POINTER(ModelWeightsCStruct),
            c_char_p,  # tensor name (bytes)
            c_void_p,  # host data pointer
        ]

    def create_weights(self, meta, device_type, ndev, dev_ids):
        """Allocate native weight storage and return its handle."""
        return self.lib.createJiugeAWQWeights(meta, device_type, ndev, dev_ids)

    def create_model(self, meta, weights):
        """Create a native model instance from metadata and loaded weights."""
        return self.lib.createJiugeAWQModel(meta, weights)

    def destroy_model(self, model):
        """Release a native model created by create_model()."""
        self.lib.destroyJiugeAWQModel(model)

    def create_kv_cache(
        self, nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev
    ):
        """Allocate a native KV cache and return its handle."""
        return self.lib.createKVCache(
            nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev
        )

    def drop_kv_cache(self, kv_cache):
        """Free a KV cache created by create_kv_cache()."""
        self.lib.dropKVCache(kv_cache)

    def load_weight(self, weights, name, data):
        """Upload one named tensor (*data* is a raw pointer) into native *weights*."""
        self.lib.loadModelWeight(weights, name.encode("utf-8"), data)

    def infer_batch(
        self,
        model,
        tokens,
        ntok,
        req_lens,
        nreq,
        req_pos,
        kv_caches,
        temperature,
        topk,
        topp,
        output,
    ):
        """Run one batched inference step.

        NOTE(review): *output* is a ``c_uint`` pointer per the argtypes
        above, presumably receiving sampled token ids -- confirm against
        the C API.
        """
        self.lib.inferBatchJiugeAWQ(
            model,
            tokens,
            ntok,
            req_lens,
            nreq,
            req_pos,
            kv_caches,
            temperature,
            topk,
            topp,
            output,
        )

    def forward_batch(
        self, model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits
    ):
        """Run one batched forward pass.

        NOTE(review): *logits* is an opaque ``c_void_p`` buffer, presumably
        receiving the output logits -- confirm against the C API.
        """
        self.lib.forwardBatchJiugeAWQ(
            model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment