Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
3998658b
Unverified
Commit
3998658b
authored
Sep 04, 2025
by
thatPepe
Committed by
GitHub
Sep 04, 2025
Browse files
Merge pull request #41 from InfiniTensor/model_scripts_modualization
Model scripts modualization
parent
8bd0f91c
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
689 additions
and
520 deletions
+689
-520
scripts/deepseek.py
scripts/deepseek.py
+21
-20
scripts/jiuge.py
scripts/jiuge.py
+21
-23
scripts/jiuge_awq.py
scripts/jiuge_awq.py
+27
-33
scripts/libinfinicore_infer.py
scripts/libinfinicore_infer.py
+0
-444
scripts/libinfinicore_infer/__init__.py
scripts/libinfinicore_infer/__init__.py
+27
-0
scripts/libinfinicore_infer/base.py
scripts/libinfinicore_infer/base.py
+68
-0
scripts/libinfinicore_infer/deepseek_v3.py
scripts/libinfinicore_infer/deepseek_v3.py
+209
-0
scripts/libinfinicore_infer/jiuge.py
scripts/libinfinicore_infer/jiuge.py
+149
-0
scripts/libinfinicore_infer/jiuge_awq.py
scripts/libinfinicore_infer/jiuge_awq.py
+167
-0
No files found.
scripts/deepseek.py
View file @
3998658b
...
...
@@ -4,17 +4,11 @@ from typing import List, Sequence
from
tqdm
import
tqdm
from
libinfinicore_infer
import
(
DeepSeekV3Model
,
DeepSeekV3MetaCStruct
,
DeepSeekV3CacheCStruct
,
DataType
,
DeviceType
,
create_deepseek_v3_model
,
create_deepseek_v3_weights
,
create_deepseek_v3_weight_loader
,
destroy_deepseek_v3_model
,
create_deepseek_v3_cache
,
drop_deepseek_v3_cache
,
infer_batch_deepseek_v3
,
)
from
infer_task
import
InferTask
,
KVCache
...
...
@@ -306,9 +300,12 @@ def load_deepseek_weights(
model_path
:
str
,
ndev
:
int
,
):
weight_loader
=
create_deepseek_v3_weight_loader
()
model_instance
=
DeepSeekV3Model
()
weight_loader
=
model_instance
.
create_weight_loader
()
names
=
DeepseekR1WeightsNaming
()
input_embd
=
load_specific_tensor
(
model_path
,
names
.
input_embd
()).
to
(
meta
.
torch_dtype_logits
)
input_embd
=
load_specific_tensor
(
model_path
,
names
.
input_embd
()).
to
(
meta
.
torch_dtype_logits
)
weight_loader
.
contents
.
load_input_embd
(
weights
,
input_embd
.
data_ptr
())
del
input_embd
...
...
@@ -590,7 +587,9 @@ class DeepSeekV3ForCauslLM:
print
(
model_dir_path
)
if
"deepseek_v3"
==
config
[
"model_type"
]:
self
.
meta
=
DeepSeekV3Meta
(
config
,
max_tokens
=
max_tokens
,
dtype
=
torch
.
float16
)
self
.
meta
=
DeepSeekV3Meta
(
config
,
max_tokens
=
max_tokens
,
dtype
=
torch
.
float16
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
)
else
:
raise
ValueError
(
"Unsupported model architecture"
)
...
...
@@ -598,16 +597,18 @@ class DeepSeekV3ForCauslLM:
print
(
f
"Creating model on
{
ndev
}
devices..."
)
load_start_time
=
time
.
time
()
dev_ids
=
(
c_int
*
ndev
)(
*
[
i
for
i
in
range
(
ndev
)])
weights
=
create_deepseek_v3_weights
(
self
.
meta
,
self
.
model_instance
=
DeepSeekV3Model
()
weights
=
self
.
model_instance
.
create_weights
(
byref
(
self
.
meta
),
device
,
ndev
,
dev_ids
,
)
# Load weights from host
#
load_deepseek_weights(self.meta, weights, model_dir_path, ndev)
load_deepseek_weights
(
self
.
meta
,
weights
,
model_dir_path
,
ndev
)
# Create model instance
self
.
model_instance
=
create_
deepseek_v3_
model
(
self
.
model_ptr
=
self
.
model_instance
.
create_model
(
byref
(
self
.
meta
),
weights
,
)
...
...
@@ -618,16 +619,16 @@ class DeepSeekV3ForCauslLM:
return
self
.
meta
.
dctx
def
create_kv_cache
(
self
):
return
create_deepseek_v3
_cache
(
self
.
model_
instance
)
return
self
.
model_instance
.
create
_cache
(
self
.
model_
ptr
)
def
drop_kv_cache
(
self
,
kv_cache
):
drop_deepseek_v3
_cache
(
self
.
model_
instance
,
kv_cache
)
self
.
model_instance
.
drop
_cache
(
self
.
model_
ptr
,
kv_cache
)
def
batch_infer_one_round
(
self
,
tasks
:
List
[
InferTask
]):
output
=
(
c_uint
*
len
(
tasks
))()
batch_inputs
=
DeepSeekV3BatchedTask
(
tasks
)
infer_batch_deepseek_v3
(
self
.
model_
instance
,
self
.
model_instance
.
infer_batch
(
self
.
model_
ptr
,
*
(
batch_inputs
.
input_args
()),
output
,
)
...
...
@@ -639,7 +640,7 @@ class DeepSeekV3ForCauslLM:
add_generation_prompt
=
True
,
tokenize
=
False
,
)
tokens
=
self
.
tokenizer
.
encode
(
input_content
)
infer_task
=
InferTask
(
0
,
...
...
@@ -736,7 +737,7 @@ class DeepSeekV3ForCauslLM:
# return math.exp(nll / total_len)
def
destroy_model_instance
(
self
):
destroy_deepseek_v3
_model
(
self
.
model_
instance
)
self
.
model_instance
.
destroy
_model
(
self
.
model_
ptr
)
print
(
"Model destroyed"
)
...
...
scripts/jiuge.py
View file @
3998658b
from
typing
import
List
,
Sequence
import
math
import
os
from
pathlib
import
Path
import
safetensors
import
sys
import
time
import
json
import
torch
import
transformers
from
sympy
import
true
from
libinfinicore_infer
import
(
JiugeModel
,
JiugeMetaCStruct
,
JiugeWeightsCStruct
,
KVCacheCStruct
,
DataType
,
DeviceType
,
create_jiuge_model
,
destroy_jiuge_model
,
create_kv_cache
,
drop_kv_cache
,
infer_batch_jiuge
,
forward_batch_jiuge
,
KVCacheCStruct
,
)
from
infer_task
import
InferTask
,
KVCache
from
ctypes
import
POINTER
,
c_float
,
c_int
,
c_uint
,
c_void_p
,
byref
import
os
from
pathlib
import
Path
import
safetensors
import
sys
import
time
import
json
import
math
import
torch
import
transformers
torch
.
set_default_device
(
"cpu"
)
...
...
@@ -419,6 +413,9 @@ class JiugeForCauslLM:
transpose_weight
=
(
device
!=
DeviceType
.
DEVICE_TYPE_ASCEND
)
# y = xW is faster than y=xW^T on Ascend
self
.
jiuge_model
=
JiugeModel
()
if
"llama"
==
config
[
"model_type"
]:
model
=
(
transformers
.
LlamaForCausalLM
.
from_pretrained
(
model_dir_path
)
...
...
@@ -509,7 +506,8 @@ class JiugeForCauslLM:
self
.
dev_ids
=
(
c_int
*
ndev
)(
*
[
i
for
i
in
range
(
ndev
)])
self
.
ndev
=
ndev
self
.
device
=
device
self
.
model_instance
=
create_jiuge_model
(
self
.
model_instance
=
self
.
jiuge_model
.
create_model
(
byref
(
self
.
meta
),
byref
(
self
.
weights
),
device
,
...
...
@@ -523,7 +521,7 @@ class JiugeForCauslLM:
return
self
.
meta
.
dctx
def
create_kv_cache
(
self
):
return
create_kv_cache
(
return
self
.
jiuge_model
.
create_kv_cache
(
self
.
meta
.
nlayer
,
self
.
meta
.
dctx
,
self
.
meta
.
nkvh
,
...
...
@@ -536,12 +534,12 @@ class JiugeForCauslLM:
)
def
drop_kv_cache
(
self
,
kv_cache
):
drop_kv_cache
(
kv_cache
)
self
.
jiuge_model
.
drop_kv_cache
(
kv_cache
)
def
batch_infer_one_round
(
self
,
tasks
:
List
[
InferTask
]):
output
=
(
c_uint
*
len
(
tasks
))()
batch_inputs
=
JiugeBatchedTask
(
tasks
)
infer_batch
_jiuge
(
self
.
jiuge_model
.
infer_batch
(
self
.
model_instance
,
*
(
batch_inputs
.
input_args
()),
output
,
...
...
@@ -621,7 +619,7 @@ class JiugeForCauslLM:
logits
=
torch
.
zeros
(
(
batch_inputs
.
ntok
,
self
.
meta
.
dvoc
),
dtype
=
self
.
meta
.
torch_dtype_logits
)
forward_batch
_jiuge
(
self
.
jiuge_model
.
forward_batch
(
self
.
model_instance
,
batch_inputs
.
tokens
,
batch_inputs
.
ntok
,
...
...
@@ -651,7 +649,7 @@ class JiugeForCauslLM:
return
math
.
exp
(
nll
/
total_len
)
def
destroy_model_instance
(
self
):
destroy_
jiuge_
model
(
self
.
model_instance
)
self
.
jiuge_model
.
destroy_model
(
self
.
model_instance
)
print
(
"Model destroyed"
)
...
...
scripts/jiuge_awq.py
View file @
3998658b
from
typing
import
List
,
Sequence
import
math
import
os
from
pathlib
import
Path
import
safetensors
import
sys
import
time
import
json
import
torch
import
transformers
from
libinfinicore_infer
import
(
JiugeAWQModel
,
JiugeAWQMetaCStruct
,
KVCacheCStruct
,
DataType
,
DeviceType
,
load_model_weight
,
create_jiuge_awq_weights
,
create_jiuge_awq_model
,
destroy_jiuge_awq_model
,
create_kv_cache
,
drop_kv_cache
,
infer_batch_jiuge_awq
,
forward_batch_jiuge_awq
,
KVCacheCStruct
,
)
from
infer_task
import
InferTask
,
KVCache
from
ctypes
import
POINTER
,
c_float
,
c_int
,
c_uint
,
c_void_p
,
byref
import
os
from
pathlib
import
Path
import
safetensors
import
sys
import
time
import
json
import
math
import
torch
import
transformers
torch
.
set_default_device
(
"cpu"
)
...
...
@@ -160,8 +153,10 @@ class JiugeAWQForCausalLM:
self
.
device
=
device
self
.
meta
=
JiugeAWQMetaFromConfig
(
config
,
max_tokens
=
max_tokens
)
self
.
weights
=
create_jiuge_awq_weights
(
self
.
meta
,
self
.
jiuge_awq_model
=
JiugeAWQModel
()
self
.
weights
=
self
.
jiuge_awq_model
.
create_weights
(
byref
(
self
.
meta
),
self
.
device
,
ndev
,
self
.
dev_ids
,
...
...
@@ -178,12 +173,9 @@ class JiugeAWQForCausalLM:
self
.
load_all_safetensors_from_dir
(
os
.
path
.
join
(
model_dir_path
))
self
.
model_instance
=
create_
jiuge_awq_model
(
self
.
meta
,
self
.
model_instance
=
self
.
jiuge_awq_model
.
create_model
(
byref
(
self
.
meta
)
,
self
.
weights
,
device
,
ndev
,
self
.
dev_ids
,
)
load_end_time
=
time
.
time
()
print
(
f
"Time used:
{
load_end_time
-
load_start_time
:.
3
f
}
s"
)
...
...
@@ -203,13 +195,15 @@ class JiugeAWQForCausalLM:
tensor
=
tensor
*
self
.
meta
.
scale_input
elif
"lm_head.weight"
in
key
:
tensor
=
tensor
*
self
.
meta
.
scale_output
load_model_weight
(
self
.
weights
,
key
,
tensor
.
data_ptr
())
self
.
jiuge_awq_model
.
load_weight
(
self
.
weights
,
key
,
tensor
.
data_ptr
()
)
def
max_context_len
(
self
):
return
self
.
meta
.
dctx
def
create_kv_cache
(
self
):
return
create_kv_cache
(
return
self
.
jiuge_awq_model
.
create_kv_cache
(
self
.
meta
.
nlayer
,
self
.
meta
.
dctx
,
self
.
meta
.
nkvh
,
...
...
@@ -222,12 +216,12 @@ class JiugeAWQForCausalLM:
)
def
drop_kv_cache
(
self
,
kv_cache
):
drop_kv_cache
(
kv_cache
)
self
.
jiuge_awq_model
.
drop_kv_cache
(
kv_cache
)
def
batch_infer_one_round
(
self
,
tasks
:
List
[
InferTask
]):
output
=
(
c_uint
*
len
(
tasks
))()
batch_inputs
=
JiugeAWQBatchedTask
(
tasks
)
infer_batch
_jiuge_awq
(
self
.
jiuge_awq_model
.
infer_batch
(
self
.
model_instance
,
*
(
batch_inputs
.
input_args
()),
output
,
...
...
@@ -308,7 +302,7 @@ class JiugeAWQForCausalLM:
logits
=
torch
.
zeros
(
(
batch_inputs
.
ntok
,
self
.
meta
.
dvoc
),
dtype
=
self
.
meta
.
torch_dtype_logits
)
forward_batch
_jiuge_awq
(
self
.
jiuge_awq_model
.
forward_batch
(
self
.
model_instance
,
batch_inputs
.
tokens
,
batch_inputs
.
ntok
,
...
...
@@ -338,14 +332,14 @@ class JiugeAWQForCausalLM:
return
math
.
exp
(
nll
/
total_len
)
def
destroy_model_instance
(
self
):
destroy_
jiuge_awq_model
(
self
.
model_instance
)
self
.
jiuge_awq_model
.
destroy_model
(
self
.
model_instance
)
print
(
"Model destroyed"
)
def
test
():
if
len
(
sys
.
argv
)
<
3
:
print
(
"Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
"Usage: python jiuge
_awq
.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
)
sys
.
exit
(
1
)
model_path
=
sys
.
argv
[
2
]
...
...
@@ -366,7 +360,7 @@ def test():
device_type
=
DeviceType
.
DEVICE_TYPE_ILUVATAR
else
:
print
(
"Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
"Usage: python
main_
jiuge
_awq
.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
)
sys
.
exit
(
1
)
...
...
scripts/libinfinicore_infer.py
deleted
100644 → 0
View file @
8bd0f91c
import
ctypes
from
ctypes
import
c_char
,
c_char_p
,
c_size_t
,
c_uint
,
c_int
,
c_float
,
c_void_p
,
POINTER
import
os
class DataType(ctypes.c_int):
    """Integer codes for element data types, carried as a C ``int``.

    Used as an argument/field type in the ctypes declarations of this module.
    NOTE(review): the numeric values presumably mirror an enum in the native
    library -- confirm against its header before changing any of them.
    """

    INFINI_DTYPE_INVALID = 0
    INFINI_DTYPE_BYTE = 1
    INFINI_DTYPE_BOOL = 2
    INFINI_DTYPE_I8 = 3
    INFINI_DTYPE_I16 = 4
    INFINI_DTYPE_I32 = 5
    INFINI_DTYPE_I64 = 6
    INFINI_DTYPE_U8 = 7
    INFINI_DTYPE_U16 = 8
    INFINI_DTYPE_U32 = 9
    INFINI_DTYPE_U64 = 10
    INFINI_DTYPE_F8 = 11
    INFINI_DTYPE_F16 = 12
    INFINI_DTYPE_F32 = 13
    INFINI_DTYPE_F64 = 14
    INFINI_DTYPE_C16 = 15
    INFINI_DTYPE_C32 = 16
    INFINI_DTYPE_C64 = 17
    INFINI_DTYPE_C128 = 18
    INFINI_DTYPE_BF16 = 19
class DeviceType(ctypes.c_int):
    """Integer codes identifying the target device backend, carried as a C ``int``.

    NOTE(review): the numeric values presumably mirror an enum in the native
    library -- confirm against its header before changing any of them.
    """

    DEVICE_TYPE_CPU = 0
    DEVICE_TYPE_NVIDIA = 1
    DEVICE_TYPE_CAMBRICON = 2
    DEVICE_TYPE_ASCEND = 3
    DEVICE_TYPE_METAX = 4
    DEVICE_TYPE_MOORE = 5
    DEVICE_TYPE_ILUVATAR = 6
class JiugeMetaCStruct(ctypes.Structure):
    """ctypes layout of the Jiuge model metadata struct.

    A pointer to this struct is declared as the first argument of
    ``createJiugeModel``. Field order and types define the binary layout, so
    they must stay in sync with the native definition.
    """

    _fields_ = [
        ("dt_logits", DataType),
        ("nlayer", c_size_t),
        ("d", c_size_t),
        ("nh", c_size_t),
        ("nkvh", c_size_t),
        ("dh", c_size_t),
        ("di", c_size_t),
        ("dctx", c_size_t),
        ("dvoc", c_size_t),
        ("epsilon", c_float),
        ("theta", c_float),
        ("end_token", c_uint),
    ]
# Define the JiugeWeights struct
class JiugeWeightsCStruct(ctypes.Structure):
    """ctypes layout of the Jiuge weights struct.

    A pointer to this struct is declared as the second argument of
    ``createJiugeModel``. The three global tensors are raw ``void *`` values;
    the remaining entries are pointers to arrays of ``void *``.
    NOTE(review): the arrays are presumably indexed per layer -- confirm.
    """

    _fields_ = [
        ("nlayer", c_size_t),
        ("dt_norm", DataType),
        ("dt_mat", DataType),
        ("transpose_linear_weights", c_int),
        ("input_embd", c_void_p),
        ("output_norm", c_void_p),
        ("output_embd", c_void_p),
        ("attn_norm", POINTER(c_void_p)),
        ("attn_qkv", POINTER(c_void_p)),
        ("attn_qkv_b", POINTER(c_void_p)),
        ("attn_o", POINTER(c_void_p)),
        ("ffn_norm", POINTER(c_void_p)),
        ("ffn_gate_up", POINTER(c_void_p)),
        ("ffn_down", POINTER(c_void_p)),
    ]
class JiugeModelCSruct(ctypes.Structure):
    # Opaque handle type: no fields are declared, so it is only used through
    # POINTER(JiugeModelCSruct) in the function declarations of this module.
    pass
class DeepSeekV3MetaCStruct(ctypes.Structure):
    """ctypes layout of the DeepSeekV3 model metadata struct.

    A pointer to this struct is declared as the first argument of
    ``createDeepSeekV3Weights`` and ``createDeepSeekV3Model``. Field order and
    types define the binary layout and must match the native definition.
    """

    _fields_ = [
        # dtypes
        ("dt_logits", DataType),
        ("dt_norm", DataType),
        ("dt_quant_weight", DataType),
        ("dt_quant_scale", DataType),
        ("dt_quant_zero", DataType),
        ("dt_gate_weight", DataType),
        ("dt_gate_bias", DataType),
        # sizes
        ("n_sparse_layer", c_size_t),
        ("n_dense_layer", c_size_t),
        ("d", c_size_t),
        ("nh", c_size_t),
        ("nkvh", c_size_t),
        ("d_rope", c_size_t),
        ("d_nope", c_size_t),
        ("r_q", c_size_t),
        ("r_kv", c_size_t),
        ("d_qk", c_size_t),
        ("d_v", c_size_t),
        # routing / experts / vocab / ctx
        ("routed_scale", c_float),
        ("nexperts", c_size_t),
        ("kexperts", c_size_t),
        ("di", c_size_t),
        ("di_moe", c_size_t),
        ("dctx", c_size_t),
        ("dvoc", c_size_t),
        # misc
        ("epsilon", c_float),
        ("rope_theta", c_float),
        ("end_token", c_uint),
    ]
class DeepSeekV3WeightsCStruct(ctypes.Structure):
    # Opaque handle type: no fields are declared; it is returned by
    # createDeepSeekV3Weights and passed to the weight-loader callbacks by pointer.
    pass
# Function-pointer prototypes used as field types of the weight-loader struct.
# Every prototype returns void and takes the weights handle as first argument.

# void (*load_global_fn)(DeepSeekV3Weights*, void *cpu_ptr)
load_global_fn = ctypes.CFUNCTYPE(None, POINTER(DeepSeekV3WeightsCStruct), c_void_p)

# void (*load_layer_fn)(DeepSeekV3Weights*, void *cpu_ptr, size_t layer_id)
load_layer_fn = ctypes.CFUNCTYPE(
    None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_size_t
)

# void (*load_layer_linear_fn)(DeepSeekV3Weights*, void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer_id)
load_layer_linear_fn = ctypes.CFUNCTYPE(
    None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_void_p, c_void_p, c_size_t
)

# void (*load_layer_mlp_fn)(DeepSeekV3Weights*, ... , size_t layer_id)
# Nine void* data pointers precede the layer index.
load_layer_mlp_fn = ctypes.CFUNCTYPE(
    None,
    POINTER(DeepSeekV3WeightsCStruct),
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_size_t,
)

# void (*load_layer_expert_mlp_fn)(DeepSeekV3Weights*, ..., size_t layer_id, size_t expert_id)
# Nine void* data pointers precede the layer and expert indices.
load_layer_expert_mlp_fn = ctypes.CFUNCTYPE(
    None,
    POINTER(DeepSeekV3WeightsCStruct),
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_size_t,
    c_size_t,
)
# -------------------------------------------------------------------
# Struct containing all weight loading functions
# -------------------------------------------------------------------
class DeepSeekV3WeightLoaderCStruct(ctypes.Structure):
    """Table of function pointers returned (by pointer) from
    ``createDeepSeekV3WeightLoader``.

    Each field is a callback of one of the ``load_*_fn`` prototypes; field
    order defines the binary layout and must match the native struct.
    """

    _fields_ = [
        # Global
        ("load_input_embd", load_global_fn),
        ("load_output_norm", load_global_fn),
        ("load_output_embd", load_global_fn),
        # Attention
        ("load_attn_norm", load_layer_fn),
        ("load_attn_q_a_proj", load_layer_linear_fn),
        ("load_attn_q_a_layernorm", load_layer_fn),
        ("load_attn_q_b_proj", load_layer_linear_fn),
        ("load_attn_kv_a_proj_with_mqa", load_layer_linear_fn),
        ("load_attn_kv_a_layernorm", load_layer_fn),
        ("load_attn_kv_b_proj", load_layer_linear_fn),
        ("load_attn_o_proj", load_layer_linear_fn),
        # MLP
        ("load_mlp_norm", load_layer_fn),
        # MLP dense part
        ("load_mlp_dense", load_layer_mlp_fn),
        # MLP sparse gating
        ("load_mlp_gate_weight", load_layer_fn),
        ("load_mlp_gate_bias", load_layer_fn),
        # Shared experts
        ("load_mlp_shared_experts", load_layer_mlp_fn),
        # Per-expert functions
        ("load_mlp_experts", load_layer_expert_mlp_fn),
    ]
class DeepSeekV3ModelCStruct(ctypes.Structure):
    # Opaque handle type: no fields are declared; createDeepSeekV3Model
    # returns a pointer to it.
    pass
class KVCacheCStruct(ctypes.Structure):
    # Opaque handle type: no fields are declared; createKVCache returns a
    # pointer to it and dropKVCache takes one.
    pass
class DeepSeekV3CacheCStruct(ctypes.Structure):
    # Opaque handle type: no fields are declared; createDeepSeekV3Cache
    # returns a pointer to it.
    pass
class JiugeAWQMetaCStruct(ctypes.Structure):
    """ctypes layout of the Jiuge AWQ model metadata struct.

    A pointer to this struct is declared as the first argument of
    ``createJiugeAWQWeights`` and ``createJiugeAWQModel``. Field order and
    types define the binary layout and must match the native definition.
    """

    _fields_ = [
        ("dt_logits", DataType),
        ("dt_linear_w", DataType),
        ("dt_norm_w", DataType),
        ("nlayer", c_size_t),
        ("d", c_size_t),
        ("nh", c_size_t),
        ("nkvh", c_size_t),
        ("dh", c_size_t),
        ("di", c_size_t),
        ("dctx", c_size_t),
        ("dvoc", c_size_t),
        ("epsilon", c_float),
        ("theta", c_float),
        ("end_token", c_uint),
        ("nbit", c_size_t),
        ("quant_group_size", c_size_t),
        # Declared as a single C char; NOTE(review): presumably a boolean flag -- confirm.
        ("has_qkv_bias", c_char),
    ]
class ModelWeightsCStruct(ctypes.Structure):
    # Opaque handle type: no fields are declared; createJiugeAWQWeights
    # returns a pointer to it and loadModelWeight takes one.
    pass
class JiugeAWQModelCStruct(ctypes.Structure):
    # No fields are declared; createJiugeAWQModel returns a pointer to it.
    pass  # opaque struct
def __open_library__():
    """Load ``libinfinicore_infer.so`` and declare its function signatures.

    The shared library is looked up at ``$INFINI_ROOT/lib/libinfinicore_infer.so``.
    After loading, ``argtypes``/``restype`` are set on every exported function
    used by this module so ctypes converts arguments and return values
    correctly (functions returning ``void`` get ``restype = None``).

    Returns:
        The configured ``ctypes.CDLL`` handle.

    Raises:
        RuntimeError: if the ``INFINI_ROOT`` environment variable is not set.
        OSError: if the shared library cannot be loaded by ``ctypes.CDLL``.
    """
    infini_root = os.environ.get("INFINI_ROOT")
    if infini_root is None:
        # Without this check os.path.join(None, ...) fails with an opaque
        # TypeError that does not mention the missing variable.
        raise RuntimeError(
            "INFINI_ROOT environment variable is not set; "
            "cannot locate lib/libinfinicore_infer.so"
        )
    lib_path = os.path.join(infini_root, "lib", "libinfinicore_infer.so")
    lib = ctypes.CDLL(lib_path)

    # ---- shared KV cache ----
    lib.createKVCache.argtypes = [
        c_size_t,  # nlayers
        c_size_t,  # max_len
        c_size_t,  # nkvh_
        c_size_t,  # dk
        c_size_t,  # dv
        DataType,  # dtype
        DeviceType,  # device
        POINTER(c_int),  # dev_ids
        c_size_t,  # ndev
    ]
    lib.createKVCache.restype = POINTER(KVCacheCStruct)
    lib.dropKVCache.argtypes = [POINTER(KVCacheCStruct)]
    lib.dropKVCache.restype = None

    # ---- Jiuge ----
    lib.createJiugeModel.restype = POINTER(JiugeModelCSruct)
    lib.createJiugeModel.argtypes = [
        POINTER(JiugeMetaCStruct),  # JiugeMeta const *
        POINTER(JiugeWeightsCStruct),  # JiugeWeights const *
        DeviceType,  # DeviceType
        c_int,  # int ndev
        POINTER(c_int),  # int const *dev_ids
    ]
    lib.destroyJiugeModel.argtypes = [POINTER(JiugeModelCSruct)]
    lib.destroyJiugeModel.restype = None
    lib.inferBatchJiuge.restype = None
    lib.inferBatchJiuge.argtypes = [
        POINTER(JiugeModelCSruct),  # struct JiugeModel const *
        POINTER(c_uint),  # unsigned int const *tokens
        c_uint,  # unsigned int ntok
        POINTER(c_uint),  # unsigned int const *req_lens
        c_uint,  # unsigned int nreq
        POINTER(c_uint),  # unsigned int const *req_pos
        POINTER(POINTER(KVCacheCStruct)),  # struct KVCache **kv_caches
        POINTER(c_float),  # float const *temperature
        POINTER(c_uint),  # unsigned int const *topk
        POINTER(c_float),  # float const *topp
        POINTER(c_uint),  # unsigned int *output
    ]
    lib.forwardBatchJiuge.restype = None
    lib.forwardBatchJiuge.argtypes = [
        POINTER(JiugeModelCSruct),  # struct JiugeModel const *
        POINTER(c_uint),  # unsigned int const *tokens
        c_uint,  # unsigned int ntok
        POINTER(c_uint),  # unsigned int const *req_lens
        c_uint,  # unsigned int nreq
        POINTER(c_uint),  # unsigned int const *req_pos
        POINTER(POINTER(KVCacheCStruct)),  # struct KVCache **kv_caches
        c_void_p,  # void *logits
    ]

    # ---- DeepSeekV3 ----
    # createDeepSeekV3WeightLoader
    lib.createDeepSeekV3WeightLoader.argtypes = []
    lib.createDeepSeekV3WeightLoader.restype = POINTER(DeepSeekV3WeightLoaderCStruct)
    lib.createDeepSeekV3Weights.argtypes = [
        POINTER(DeepSeekV3MetaCStruct),
        DeviceType,
        c_int,
        POINTER(c_int),
    ]
    lib.createDeepSeekV3Weights.restype = POINTER(DeepSeekV3WeightsCStruct)
    lib.createDeepSeekV3Model.argtypes = [
        POINTER(DeepSeekV3MetaCStruct),
        POINTER(DeepSeekV3WeightsCStruct),
    ]
    lib.createDeepSeekV3Model.restype = POINTER(DeepSeekV3ModelCStruct)
    # destroyDeepSeekV3Model
    lib.destroyDeepSeekV3Model.argtypes = [POINTER(DeepSeekV3ModelCStruct)]
    lib.destroyDeepSeekV3Model.restype = None
    # createDeepSeekV3Cache
    lib.createDeepSeekV3Cache.argtypes = [POINTER(DeepSeekV3ModelCStruct)]
    lib.createDeepSeekV3Cache.restype = POINTER(DeepSeekV3CacheCStruct)
    # dropDeepSeekV3Cache
    lib.dropDeepSeekV3Cache.argtypes = [
        POINTER(DeepSeekV3ModelCStruct),
        POINTER(DeepSeekV3CacheCStruct),
    ]
    lib.dropDeepSeekV3Cache.restype = None
    # inferBatchDeepSeekV3
    lib.inferBatchDeepSeekV3.argtypes = [
        POINTER(DeepSeekV3ModelCStruct),
        POINTER(c_uint),
        c_uint,
        POINTER(c_uint),
        c_uint,
        POINTER(c_uint),
        POINTER(POINTER(DeepSeekV3CacheCStruct)),
        POINTER(c_float),
        POINTER(c_uint),
        POINTER(c_float),
        POINTER(c_uint),
    ]
    lib.inferBatchDeepSeekV3.restype = None
    # forwardBatchDeepSeekV3
    lib.forwardBatchDeepSeekV3.argtypes = [
        POINTER(DeepSeekV3ModelCStruct),
        POINTER(c_uint),
        c_uint,
        POINTER(c_uint),
        c_uint,
        POINTER(c_uint),
        POINTER(POINTER(DeepSeekV3CacheCStruct)),
        c_void_p,
    ]
    lib.forwardBatchDeepSeekV3.restype = None

    # ---- Jiuge AWQ ----
    lib.createJiugeAWQWeights.restype = POINTER(ModelWeightsCStruct)
    lib.createJiugeAWQWeights.argtypes = [
        POINTER(JiugeAWQMetaCStruct),  # const JiugeAWQMeta*
        DeviceType,  # infiniDevice_t
        c_int,  # int ndev
        POINTER(c_int),  # const int* dev_ids
    ]
    # createJiugeAWQModel
    lib.createJiugeAWQModel.restype = POINTER(JiugeAWQModelCStruct)
    lib.createJiugeAWQModel.argtypes = [
        POINTER(JiugeAWQMetaCStruct),  # const JiugeAWQMeta*
        POINTER(ModelWeightsCStruct),  # const ModelWeights*
    ]
    # destroyJiugeAWQModel
    lib.destroyJiugeAWQModel.argtypes = [POINTER(JiugeAWQModelCStruct)]
    lib.destroyJiugeAWQModel.restype = None
    # inferBatchJiugeAWQ
    lib.inferBatchJiugeAWQ.argtypes = [
        POINTER(JiugeAWQModelCStruct),  # JiugeAWQModel*
        POINTER(c_uint),  # const uint32_t* tokens
        c_uint,  # uint32_t ntok
        POINTER(c_uint),  # const uint32_t* req_lens
        c_uint,  # uint32_t nreq
        POINTER(c_uint),  # const uint32_t* req_pos
        POINTER(POINTER(KVCacheCStruct)),  # struct KVCache** kv_caches
        POINTER(c_float),  # const float* temperature
        POINTER(c_uint),  # const uint32_t* topk
        POINTER(c_float),  # const float* topp
        POINTER(c_uint),  # uint32_t* output
    ]
    lib.inferBatchJiugeAWQ.restype = None
    # forwardBatchJiugeAWQ
    lib.forwardBatchJiugeAWQ.argtypes = [
        POINTER(JiugeAWQModelCStruct),  # JiugeAWQModel*
        POINTER(c_uint),  # const uint32_t* tokens
        c_uint,  # uint32_t ntok
        POINTER(c_uint),  # const uint32_t* req_lens
        c_uint,  # uint32_t nreq
        POINTER(c_uint),  # const uint32_t* req_pos
        POINTER(POINTER(KVCacheCStruct)),  # struct KVCache** kv_caches
        c_void_p,  # void* logits
    ]
    lib.forwardBatchJiugeAWQ.restype = None
    lib.loadModelWeight.argtypes = [
        POINTER(ModelWeightsCStruct),  # struct ModelWeights*
        c_char_p,  # const char* name
        c_void_p,  # void* data
    ]
    lib.loadModelWeight.restype = None
    return lib
# Module-level library handle; the shared library is loaded as a side effect
# of importing this module.
LIB = __open_library__()
def load_model_weight(weights, name, data):
    """Forward one named weight to the native ``loadModelWeight`` function.

    Args:
        weights: handle passed through unchanged as the first native argument.
        name: Python ``str``; encoded to UTF-8 bytes before the call.
        data: value passed through unchanged as the data pointer argument.
    """
    encoded_name = name.encode("utf-8")
    LIB.loadModelWeight(weights, encoded_name, data)
# Snake-case module-level aliases for the native entry points on LIB.

# Jiuge
create_jiuge_model = LIB.createJiugeModel
destroy_jiuge_model = LIB.destroyJiugeModel
create_kv_cache = LIB.createKVCache
drop_kv_cache = LIB.dropKVCache
infer_batch_jiuge = LIB.inferBatchJiuge
forward_batch_jiuge = LIB.forwardBatchJiuge

# Jiuge AWQ
create_jiuge_awq_weights = LIB.createJiugeAWQWeights
create_jiuge_awq_model = LIB.createJiugeAWQModel
destroy_jiuge_awq_model = LIB.destroyJiugeAWQModel
infer_batch_jiuge_awq = LIB.inferBatchJiugeAWQ
forward_batch_jiuge_awq = LIB.forwardBatchJiugeAWQ

# DeepSeekV3
create_deepseek_v3_model = LIB.createDeepSeekV3Model
destroy_deepseek_v3_model = LIB.destroyDeepSeekV3Model
create_deepseek_v3_weight_loader = LIB.createDeepSeekV3WeightLoader
create_deepseek_v3_weights = LIB.createDeepSeekV3Weights
create_deepseek_v3_cache = LIB.createDeepSeekV3Cache
drop_deepseek_v3_cache = LIB.dropDeepSeekV3Cache
infer_batch_deepseek_v3 = LIB.inferBatchDeepSeekV3
scripts/libinfinicore_infer/__init__.py
0 → 100644
View file @
3998658b
from
.base
import
DataType
,
DeviceType
,
KVCacheCStruct
from
.jiuge
import
JiugeModel
,
JiugeMetaCStruct
,
JiugeWeightsCStruct
from
.jiuge_awq
import
JiugeAWQModel
,
JiugeAWQMetaCStruct
,
ModelWeightsCStruct
from
.deepseek_v3
import
(
DeepSeekV3Model
,
DeepSeekV3MetaCStruct
,
DeepSeekV3WeightsCStruct
,
DeepSeekV3WeightLoaderCStruct
,
DeepSeekV3CacheCStruct
,
)
# Public API of the package. Every name listed here must be imported above:
# a name that is not bound in this module makes ``from libinfinicore_infer
# import *`` raise AttributeError.
__all__ = [
    "DataType",
    "DeviceType",
    "KVCacheCStruct",
    "JiugeModel",
    "JiugeMetaCStruct",
    "JiugeWeightsCStruct",
    "JiugeAWQModel",
    "JiugeAWQMetaCStruct",
    "ModelWeightsCStruct",
    "DeepSeekV3Model",
    "DeepSeekV3MetaCStruct",
    "DeepSeekV3WeightsCStruct",
    "DeepSeekV3WeightLoaderCStruct",
    "DeepSeekV3CacheCStruct",
]
scripts/libinfinicore_infer/base.py
0 → 100644
View file @
3998658b
import
ctypes
from
ctypes
import
c_char
,
c_char_p
,
c_size_t
,
c_uint
,
c_int
,
c_float
,
c_void_p
,
POINTER
import
os
class DataType(ctypes.c_int):
    """Integer codes for element data types, carried as a C ``int``.

    NOTE(review): the numeric values presumably mirror an enum in the native
    library -- confirm against its header before changing any of them.
    """

    INFINI_DTYPE_INVALID = 0
    INFINI_DTYPE_BYTE = 1
    INFINI_DTYPE_BOOL = 2
    INFINI_DTYPE_I8 = 3
    INFINI_DTYPE_I16 = 4
    INFINI_DTYPE_I32 = 5
    INFINI_DTYPE_I64 = 6
    INFINI_DTYPE_U8 = 7
    INFINI_DTYPE_U16 = 8
    INFINI_DTYPE_U32 = 9
    INFINI_DTYPE_U64 = 10
    INFINI_DTYPE_F8 = 11
    INFINI_DTYPE_F16 = 12
    INFINI_DTYPE_F32 = 13
    INFINI_DTYPE_F64 = 14
    INFINI_DTYPE_C16 = 15
    INFINI_DTYPE_C32 = 16
    INFINI_DTYPE_C64 = 17
    INFINI_DTYPE_C128 = 18
    INFINI_DTYPE_BF16 = 19
class DeviceType(ctypes.c_int):
    """Integer codes identifying the target device backend, carried as a C ``int``.

    NOTE(review): the numeric values presumably mirror an enum in the native
    library -- confirm against its header before changing any of them.
    """

    DEVICE_TYPE_CPU = 0
    DEVICE_TYPE_NVIDIA = 1
    DEVICE_TYPE_CAMBRICON = 2
    DEVICE_TYPE_ASCEND = 3
    DEVICE_TYPE_METAX = 4
    DEVICE_TYPE_MOORE = 5
    DEVICE_TYPE_ILUVATAR = 6
class KVCacheCStruct(ctypes.Structure):
    # Opaque handle type: no fields are declared, so it is only meaningful
    # when used through POINTER(KVCacheCStruct).
    pass
# Model registration system
# Classes are stored in decoration order; each must provide ``register_lib(lib)``.
_model_registry = []


def register_model(model_class):
    """Class decorator: record ``model_class`` in the registry and return it unchanged."""
    _model_registry.append(model_class)
    return model_class


def register_lib_functions(lib):
    """Call ``register_lib(lib)`` on every registered model class, in registration order."""
    for registered_class in _model_registry:
        registered_class.register_lib(lib)
class BaseModel:
    """Base class for model bindings backed by ``libinfinicore_infer.so``.

    On construction the shared library is loaded and every registered model
    class gets to declare its function signatures on the resulting handle,
    which is stored as ``self.lib``.
    """

    def __init__(self):
        self.lib = self._load_library()
        register_lib_functions(self.lib)

    def _load_library(self):
        """Load ``$INFINI_ROOT/lib/libinfinicore_infer.so`` and return the handle.

        Raises:
            RuntimeError: if the ``INFINI_ROOT`` environment variable is not set.
            OSError: if the shared library cannot be loaded by ``ctypes.CDLL``.
        """
        infini_root = os.environ.get("INFINI_ROOT")
        if infini_root is None:
            # Without this check os.path.join(None, ...) fails with an opaque
            # TypeError that does not mention the missing variable.
            raise RuntimeError(
                "INFINI_ROOT environment variable is not set; "
                "cannot locate lib/libinfinicore_infer.so"
            )
        lib_path = os.path.join(infini_root, "lib", "libinfinicore_infer.so")
        return ctypes.CDLL(lib_path)
scripts/libinfinicore_infer/deepseek_v3.py
0 → 100644
View file @
3998658b
from
.base
import
BaseModel
,
DataType
,
DeviceType
,
KVCacheCStruct
,
register_model
from
ctypes
import
(
c_size_t
,
c_uint
,
c_int
,
c_float
,
c_void_p
,
POINTER
,
Structure
,
CFUNCTYPE
,
)
class DeepSeekV3MetaCStruct(Structure):
    """ctypes layout of the DeepSeekV3 model metadata struct.

    A pointer to this struct is declared as the first argument of
    ``createDeepSeekV3Weights`` and ``createDeepSeekV3Model``. Field order and
    types define the binary layout and must match the native definition.
    """

    _fields_ = [
        # element-type codes
        ("dt_logits", DataType),
        ("dt_norm", DataType),
        ("dt_quant_weight", DataType),
        ("dt_quant_scale", DataType),
        ("dt_quant_zero", DataType),
        ("dt_gate_weight", DataType),
        ("dt_gate_bias", DataType),
        # size_t-valued fields
        ("n_sparse_layer", c_size_t),
        ("n_dense_layer", c_size_t),
        ("d", c_size_t),
        ("nh", c_size_t),
        ("nkvh", c_size_t),
        ("d_rope", c_size_t),
        ("d_nope", c_size_t),
        ("r_q", c_size_t),
        ("r_kv", c_size_t),
        ("d_qk", c_size_t),
        ("d_v", c_size_t),
        ("routed_scale", c_float),
        ("nexperts", c_size_t),
        ("kexperts", c_size_t),
        ("di", c_size_t),
        ("di_moe", c_size_t),
        ("dctx", c_size_t),
        ("dvoc", c_size_t),
        ("epsilon", c_float),
        ("rope_theta", c_float),
        ("end_token", c_uint),
    ]
class DeepSeekV3WeightsCStruct(Structure):
    # Opaque handle type: no fields are declared; createDeepSeekV3Weights
    # returns a pointer to it and the weight-loader callbacks take one.
    pass
class DeepSeekV3ModelCStruct(Structure):
    # Opaque handle type: no fields are declared; createDeepSeekV3Model
    # returns a pointer to it.
    pass
class DeepSeekV3CacheCStruct(Structure):
    # Opaque handle type: no fields are declared; createDeepSeekV3Cache
    # returns a pointer to it.
    pass
# Function-pointer prototypes used as field types of
# DeepSeekV3WeightLoaderCStruct. Every prototype returns void (None) and
# takes a pointer to the weights handle as its first argument.

# (weights, void*)
load_global_fn = CFUNCTYPE(None, POINTER(DeepSeekV3WeightsCStruct), c_void_p)

# (weights, void*, size_t)
load_layer_fn = CFUNCTYPE(None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_size_t)

# (weights, void*, void*, void*, size_t)
load_layer_linear_fn = CFUNCTYPE(
    None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_void_p, c_void_p, c_size_t
)

# (weights, nine void* values, size_t)
load_layer_mlp_fn = CFUNCTYPE(
    None,
    POINTER(DeepSeekV3WeightsCStruct),
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_size_t,
)

# (weights, nine void* values, size_t, size_t)
load_layer_expert_mlp_fn = CFUNCTYPE(
    None,
    POINTER(DeepSeekV3WeightsCStruct),
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_size_t,
    c_size_t,
)
class DeepSeekV3WeightLoaderCStruct(Structure):
    """Table of function pointers returned (by pointer) from
    ``createDeepSeekV3WeightLoader``.

    Each field is a callback of one of the ``load_*_fn`` prototypes; field
    order defines the binary layout and must match the native struct.
    """

    _fields_ = [
        ("load_input_embd", load_global_fn),
        ("load_output_norm", load_global_fn),
        ("load_output_embd", load_global_fn),
        ("load_attn_norm", load_layer_fn),
        ("load_attn_q_a_proj", load_layer_linear_fn),
        ("load_attn_q_a_layernorm", load_layer_fn),
        ("load_attn_q_b_proj", load_layer_linear_fn),
        ("load_attn_kv_a_proj_with_mqa", load_layer_linear_fn),
        ("load_attn_kv_a_layernorm", load_layer_fn),
        ("load_attn_kv_b_proj", load_layer_linear_fn),
        ("load_attn_o_proj", load_layer_linear_fn),
        ("load_mlp_norm", load_layer_fn),
        ("load_mlp_dense", load_layer_mlp_fn),
        ("load_mlp_gate_weight", load_layer_fn),
        ("load_mlp_gate_bias", load_layer_fn),
        ("load_mlp_shared_experts", load_layer_mlp_fn),
        ("load_mlp_experts", load_layer_expert_mlp_fn),
    ]
@register_model
class DeepSeekV3Model(BaseModel):
    """Thin Python wrapper around the DeepSeekV3 entry points of the native library.

    Each instance method forwards its arguments unchanged to the function of
    the corresponding name on ``self.lib``.
    """

    @classmethod
    def register_lib(cls, lib):
        """Register DeepSeekV3 model functions with the library"""
        lib.createDeepSeekV3WeightLoader.argtypes = []
        lib.createDeepSeekV3WeightLoader.restype = POINTER(
            DeepSeekV3WeightLoaderCStruct
        )

        lib.createDeepSeekV3Weights.argtypes = [
            POINTER(DeepSeekV3MetaCStruct),
            DeviceType,
            c_int,
            POINTER(c_int),
        ]
        lib.createDeepSeekV3Weights.restype = POINTER(DeepSeekV3WeightsCStruct)

        lib.createDeepSeekV3Model.argtypes = [
            POINTER(DeepSeekV3MetaCStruct),
            POINTER(DeepSeekV3WeightsCStruct),
        ]
        lib.createDeepSeekV3Model.restype = POINTER(DeepSeekV3ModelCStruct)

        # Functions that return nothing get restype = None; otherwise ctypes
        # assumes a C int return value.
        lib.destroyDeepSeekV3Model.argtypes = [POINTER(DeepSeekV3ModelCStruct)]
        lib.destroyDeepSeekV3Model.restype = None

        lib.createDeepSeekV3Cache.argtypes = [POINTER(DeepSeekV3ModelCStruct)]
        lib.createDeepSeekV3Cache.restype = POINTER(DeepSeekV3CacheCStruct)

        lib.dropDeepSeekV3Cache.argtypes = [
            POINTER(DeepSeekV3ModelCStruct),
            POINTER(DeepSeekV3CacheCStruct),
        ]
        lib.dropDeepSeekV3Cache.restype = None

        lib.inferBatchDeepSeekV3.argtypes = [
            POINTER(DeepSeekV3ModelCStruct),
            POINTER(c_uint),
            c_uint,
            POINTER(c_uint),
            c_uint,
            POINTER(c_uint),
            POINTER(POINTER(DeepSeekV3CacheCStruct)),
            POINTER(c_float),
            POINTER(c_uint),
            POINTER(c_float),
            POINTER(c_uint),
        ]
        lib.inferBatchDeepSeekV3.restype = None

    def create_weight_loader(self):
        """Return the pointer produced by ``createDeepSeekV3WeightLoader``."""
        return self.lib.createDeepSeekV3WeightLoader()

    def create_weights(self, meta, device_type, ndev, dev_ids):
        """Return the weights handle produced by ``createDeepSeekV3Weights``."""
        return self.lib.createDeepSeekV3Weights(meta, device_type, ndev, dev_ids)

    def create_model(self, meta, weights):
        """Return the model handle produced by ``createDeepSeekV3Model``."""
        return self.lib.createDeepSeekV3Model(meta, weights)

    def destroy_model(self, model):
        """Pass ``model`` to ``destroyDeepSeekV3Model``."""
        self.lib.destroyDeepSeekV3Model(model)

    def create_cache(self, model):
        """Return the cache handle produced by ``createDeepSeekV3Cache``."""
        return self.lib.createDeepSeekV3Cache(model)

    def drop_cache(self, model, cache):
        """Pass ``model`` and ``cache`` to ``dropDeepSeekV3Cache``."""
        self.lib.dropDeepSeekV3Cache(model, cache)

    def infer_batch(
        self,
        model,
        tokens,
        ntok,
        req_lens,
        nreq,
        req_pos,
        caches,
        temperature,
        topk,
        topp,
        output,
    ):
        """Forward one batch to ``inferBatchDeepSeekV3``; nothing is returned.

        Arguments are passed positionally in the order declared in
        ``register_lib``.
        """
        self.lib.inferBatchDeepSeekV3(
            model,
            tokens,
            ntok,
            req_lens,
            nreq,
            req_pos,
            caches,
            temperature,
            topk,
            topp,
            output,
        )
scripts/libinfinicore_infer/jiuge.py
0 → 100644
View file @
3998658b
from
.base
import
BaseModel
,
DataType
,
DeviceType
,
KVCacheCStruct
,
register_model
from
ctypes
import
c_size_t
,
c_uint
,
c_int
,
c_float
,
c_void_p
,
POINTER
,
Structure
,
byref
class JiugeMetaCStruct(Structure):
    """ctypes layout of the Jiuge model metadata struct.

    A pointer to this struct is declared as the first argument of
    ``createJiugeModel``. Field order and types define the binary layout and
    must match the native definition.
    """

    _fields_ = [
        ("dt_logits", DataType),
        ("nlayer", c_size_t),
        ("d", c_size_t),
        ("nh", c_size_t),
        ("nkvh", c_size_t),
        ("dh", c_size_t),
        ("di", c_size_t),
        ("dctx", c_size_t),
        ("dvoc", c_size_t),
        ("epsilon", c_float),
        ("theta", c_float),
        ("end_token", c_uint),
    ]
class JiugeWeightsCStruct(Structure):
    """ctypes layout of the Jiuge weights struct.

    A pointer to this struct is declared as the second argument of
    ``createJiugeModel``. The three global tensors are raw ``void *`` values;
    the remaining entries are pointers to arrays of ``void *``.
    NOTE(review): the arrays are presumably indexed per layer -- confirm.
    """

    _fields_ = [
        ("nlayer", c_size_t),
        ("dt_norm", DataType),
        ("dt_mat", DataType),
        ("transpose_linear_weights", c_int),
        ("input_embd", c_void_p),
        ("output_norm", c_void_p),
        ("output_embd", c_void_p),
        ("attn_norm", POINTER(c_void_p)),
        ("attn_qkv", POINTER(c_void_p)),
        ("attn_qkv_b", POINTER(c_void_p)),
        ("attn_o", POINTER(c_void_p)),
        ("ffn_norm", POINTER(c_void_p)),
        ("ffn_gate_up", POINTER(c_void_p)),
        ("ffn_down", POINTER(c_void_p)),
    ]
class JiugeModelCStruct(Structure):
    # Opaque handle type: no fields are declared, so it is only used through
    # POINTER(JiugeModelCStruct).
    pass
@register_model
class JiugeModel(BaseModel):
    """Binding for the Jiuge functions exported by the native library.

    ``register_lib`` declares the C signatures on a library handle. Every
    instance method is a thin pass-through that calls the matching function
    on ``self.lib`` with its arguments unchanged and in the same order.
    ``self.lib`` is not assigned in this class (presumably ``BaseModel``
    provides it -- confirm in ``base.py``).
    """

    @classmethod
    def register_lib(cls, lib):
        """Declare argument and result types for the Jiuge C functions.

        Args:
            lib: Library handle whose function attributes are configured
                in place.
        """
        model_ptr = POINTER(JiugeModelCStruct)
        kv_cache_ptr = POINTER(KVCacheCStruct)

        lib.createJiugeModel.restype = model_ptr
        lib.createJiugeModel.argtypes = [
            POINTER(JiugeMetaCStruct),  # meta
            POINTER(JiugeWeightsCStruct),  # weights
            DeviceType,  # device_type
            c_int,  # ndev
            POINTER(c_int),  # dev_ids
        ]

        lib.destroyJiugeModel.argtypes = [model_ptr]

        lib.createKVCache.argtypes = [
            c_size_t,  # nlayer
            c_size_t,  # max_len
            c_size_t,  # nkvh
            c_size_t,  # dk
            c_size_t,  # dv
            DataType,  # dtype
            DeviceType,  # device
            POINTER(c_int),  # dev_ids
            c_size_t,  # ndev
        ]
        lib.createKVCache.restype = kv_cache_ptr

        lib.dropKVCache.argtypes = [kv_cache_ptr]

        lib.inferBatchJiuge.argtypes = [
            model_ptr,  # model
            POINTER(c_uint),  # tokens
            c_uint,  # ntok
            POINTER(c_uint),  # req_lens
            c_uint,  # nreq
            POINTER(c_uint),  # req_pos
            POINTER(kv_cache_ptr),  # kv_caches
            POINTER(c_float),  # temperature
            POINTER(c_uint),  # topk
            POINTER(c_float),  # topp
            POINTER(c_uint),  # output
        ]

        lib.forwardBatchJiuge.argtypes = [
            model_ptr,  # model
            POINTER(c_uint),  # tokens
            c_uint,  # ntok
            POINTER(c_uint),  # req_lens
            c_uint,  # nreq
            POINTER(c_uint),  # req_pos
            POINTER(kv_cache_ptr),  # kv_caches
            c_void_p,  # logits
        ]

        # ctypes assumes a C int result unless told otherwise. The wrappers
        # below never use these functions' results, so declare that there is
        # none instead of letting ctypes read a meaningless int.
        # NOTE(review): assumes the C prototypes return void -- confirm
        # against the native header.
        for no_result_fn in (
            lib.destroyJiugeModel,
            lib.dropKVCache,
            lib.inferBatchJiuge,
            lib.forwardBatchJiuge,
        ):
            no_result_fn.restype = None

    def create_model(self, meta, weights, device_type, ndev, dev_ids):
        """Call ``createJiugeModel`` and return the model pointer it yields."""
        return self.lib.createJiugeModel(meta, weights, device_type, ndev, dev_ids)

    def destroy_model(self, model):
        """Pass ``model`` to ``destroyJiugeModel``; returns ``None``."""
        self.lib.destroyJiugeModel(model)

    def create_kv_cache(
        self, nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev
    ):
        """Call ``createKVCache`` and return the cache pointer it yields."""
        return self.lib.createKVCache(
            nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev
        )

    def drop_kv_cache(self, kv_cache):
        """Pass ``kv_cache`` to ``dropKVCache``; returns ``None``."""
        self.lib.dropKVCache(kv_cache)

    def infer_batch(
        self,
        model,
        tokens,
        ntok,
        req_lens,
        nreq,
        req_pos,
        kv_caches,
        temperature,
        topk,
        topp,
        output,
    ):
        """Forward one batch to ``inferBatchJiuge``; returns ``None``.

        ``output`` is declared as a ``POINTER(c_uint)`` argument and is
        presumably written by the native side (confirm against the header).
        """
        self.lib.inferBatchJiuge(
            model,
            tokens,
            ntok,
            req_lens,
            nreq,
            req_pos,
            kv_caches,
            temperature,
            topk,
            topp,
            output,
        )

    def forward_batch(
        self, model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits
    ):
        """Forward one batch to ``forwardBatchJiuge``; returns ``None``.

        ``logits`` is declared as an untyped pointer (``c_void_p``) and is
        presumably the buffer the native side fills (confirm against the
        header).
        """
        self.lib.forwardBatchJiuge(
            model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits
        )
scripts/libinfinicore_infer/jiuge_awq.py
0 → 100644
View file @
3998658b
from
.base
import
BaseModel
,
DataType
,
DeviceType
,
KVCacheCStruct
,
register_model
from
ctypes
import
(
c_size_t
,
c_uint
,
c_int
,
c_float
,
c_void_p
,
POINTER
,
Structure
,
c_char
,
c_char_p
,
)
class JiugeAWQMetaCStruct(Structure):
    """ctypes mirror of the metadata record used by the JiugeAWQ functions.

    Pointers to this type are declared as the first argument of both
    ``createJiugeAWQWeights`` and ``createJiugeAWQModel``. ctypes lays the
    members out in the order listed, so order and C types must stay as is.
    """

    _fields_ = [
        # DataType members
        ("dt_logits", DataType),
        ("dt_linear_w", DataType),
        ("dt_norm_w", DataType),
        # size_t members
        ("nlayer", c_size_t),
        ("d", c_size_t),
        ("nh", c_size_t),
        ("nkvh", c_size_t),
        ("dh", c_size_t),
        ("di", c_size_t),
        ("dctx", c_size_t),
        ("dvoc", c_size_t),
        # float members
        ("epsilon", c_float),
        ("theta", c_float),
        ("end_token", c_uint),
        # Remaining members
        ("nbit", c_size_t),
        ("quant_group_size", c_size_t),
        # Declared as a single C char (presumably used as a flag -- confirm
        # against the C header).
        ("has_qkv_bias", c_char),
    ]
class ModelWeightsCStruct(Structure):
    """Field-less ctypes type for the native weights object.

    No ``_fields_`` are declared; in this file the type is only used as
    ``POINTER(ModelWeightsCStruct)`` in the function signatures.
    """
class JiugeAWQModelCStruct(Structure):
    """Field-less ctypes type for the native JiugeAWQ model.

    No ``_fields_`` are declared; in this file the type is only used as
    ``POINTER(JiugeAWQModelCStruct)`` in the function signatures.
    """
@register_model
class JiugeAWQModel(BaseModel):
    """Binding for the JiugeAWQ functions exported by the native library.

    ``register_lib`` declares the C signatures on a library handle. Every
    instance method is a thin pass-through that calls the matching function
    on ``self.lib``. ``self.lib`` is not assigned in this class (presumably
    ``BaseModel`` provides it -- confirm in ``base.py``).
    """

    @classmethod
    def register_lib(cls, lib):
        """Register JiugeAWQ model functions with the library.

        Args:
            lib: Library handle whose function attributes are configured
                in place.
        """
        meta_ptr = POINTER(JiugeAWQMetaCStruct)
        weights_ptr = POINTER(ModelWeightsCStruct)
        model_ptr = POINTER(JiugeAWQModelCStruct)
        kv_cache_ptr = POINTER(KVCacheCStruct)

        lib.createJiugeAWQWeights.restype = weights_ptr
        lib.createJiugeAWQWeights.argtypes = [
            meta_ptr,  # meta
            DeviceType,  # device_type
            c_int,  # ndev
            POINTER(c_int),  # dev_ids
        ]

        lib.createJiugeAWQModel.restype = model_ptr
        lib.createJiugeAWQModel.argtypes = [
            meta_ptr,  # meta
            weights_ptr,  # weights
        ]

        lib.destroyJiugeAWQModel.argtypes = [model_ptr]

        lib.createKVCache.argtypes = [
            c_size_t,  # nlayer
            c_size_t,  # max_len
            c_size_t,  # nkvh
            c_size_t,  # dk
            c_size_t,  # dv
            DataType,  # dtype
            DeviceType,  # device
            POINTER(c_int),  # dev_ids
            c_size_t,  # ndev
        ]
        lib.createKVCache.restype = kv_cache_ptr

        lib.dropKVCache.argtypes = [kv_cache_ptr]

        lib.inferBatchJiugeAWQ.argtypes = [
            model_ptr,  # model
            POINTER(c_uint),  # tokens
            c_uint,  # ntok
            POINTER(c_uint),  # req_lens
            c_uint,  # nreq
            POINTER(c_uint),  # req_pos
            POINTER(kv_cache_ptr),  # kv_caches
            POINTER(c_float),  # temperature
            POINTER(c_uint),  # topk
            POINTER(c_float),  # topp
            POINTER(c_uint),  # output
        ]

        lib.forwardBatchJiugeAWQ.argtypes = [
            model_ptr,  # model
            POINTER(c_uint),  # tokens
            c_uint,  # ntok
            POINTER(c_uint),  # req_lens
            c_uint,  # nreq
            POINTER(c_uint),  # req_pos
            POINTER(kv_cache_ptr),  # kv_caches
            c_void_p,  # logits
        ]

        lib.loadModelWeight.argtypes = [
            weights_ptr,  # weights
            c_char_p,  # name (encoded bytes)
            c_void_p,  # data
        ]

        # ctypes assumes a C int result unless told otherwise. The wrappers
        # below never use these functions' results, so declare that there is
        # none instead of letting ctypes read a meaningless int.
        # NOTE(review): assumes the C prototypes return void -- confirm
        # against the native header.
        for no_result_fn in (
            lib.destroyJiugeAWQModel,
            lib.dropKVCache,
            lib.inferBatchJiugeAWQ,
            lib.forwardBatchJiugeAWQ,
            lib.loadModelWeight,
        ):
            no_result_fn.restype = None

    def create_weights(self, meta, device_type, ndev, dev_ids):
        """Call ``createJiugeAWQWeights`` and return the pointer it yields."""
        return self.lib.createJiugeAWQWeights(meta, device_type, ndev, dev_ids)

    def create_model(self, meta, weights):
        """Call ``createJiugeAWQModel`` and return the model pointer."""
        return self.lib.createJiugeAWQModel(meta, weights)

    def destroy_model(self, model):
        """Pass ``model`` to ``destroyJiugeAWQModel``; returns ``None``."""
        self.lib.destroyJiugeAWQModel(model)

    def create_kv_cache(
        self, nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev
    ):
        """Call ``createKVCache`` and return the cache pointer it yields."""
        return self.lib.createKVCache(
            nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev
        )

    def drop_kv_cache(self, kv_cache):
        """Pass ``kv_cache`` to ``dropKVCache``; returns ``None``."""
        self.lib.dropKVCache(kv_cache)

    def load_weight(self, weights, name, data):
        """Pass one named weight to ``loadModelWeight``; returns ``None``.

        Args:
            weights: Weights pointer, as returned by ``create_weights``.
            name: Weight name as ``str``; it is UTF-8 encoded here because
                the C parameter is declared as ``c_char_p``.
            data: Untyped pointer (``c_void_p``) forwarded unchanged.
        """
        self.lib.loadModelWeight(weights, name.encode("utf-8"), data)

    def infer_batch(
        self,
        model,
        tokens,
        ntok,
        req_lens,
        nreq,
        req_pos,
        kv_caches,
        temperature,
        topk,
        topp,
        output,
    ):
        """Forward one batch to ``inferBatchJiugeAWQ``; returns ``None``.

        ``output`` is declared as a ``POINTER(c_uint)`` argument and is
        presumably written by the native side (confirm against the header).
        """
        self.lib.inferBatchJiugeAWQ(
            model,
            tokens,
            ntok,
            req_lens,
            nreq,
            req_pos,
            kv_caches,
            temperature,
            topk,
            topp,
            output,
        )

    def forward_batch(
        self, model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits
    ):
        """Forward one batch to ``forwardBatchJiugeAWQ``; returns ``None``.

        ``logits`` is declared as an untyped pointer (``c_void_p``) and is
        presumably the buffer the native side fills (confirm against the
        header).
        """
        self.lib.forwardBatchJiugeAWQ(
            model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits
        )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment