Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
2b9ce5a6
Commit
2b9ce5a6
authored
Jun 25, 2025
by
Pan Zezhong
Browse files
refactor task to avoid model dependency on async libs
parent
cfc8b598
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
171 additions
and
122 deletions
+171
-122
scripts/infer_task.py
scripts/infer_task.py
+33
-12
scripts/jiuge.py
scripts/jiuge.py
+105
-69
scripts/kvcache_pool.py
scripts/kvcache_pool.py
+8
-27
scripts/launch_server.py
scripts/launch_server.py
+12
-1
scripts/libinfinicore_infer.py
scripts/libinfinicore_infer.py
+13
-13
No files found.
scripts/infer_task.py
View file @
2b9ce5a6
import
janus
class
InferTask
:
def
__init__
(
self
,
id
,
tokens
,
max_tokens
,
temperature
,
topk
,
topp
,
end_tokens
):
self
.
id
=
id
...
...
@@ -11,21 +8,24 @@ class InferTask:
self
.
topk
=
topk
self
.
topp
=
topp
self
.
end_tokens
=
end_tokens
self
.
output_queue
=
janus
.
Queue
()
self
.
_kv_cache_pool_item
=
None
self
.
_kv_cache
=
None
self
.
pos
=
0
print
(
f
"[INFO] Create InferTask
{
self
.
id
}
"
)
def
bind_kvcache
(
self
,
kv_cache
_pool_item
,
pos
):
self
.
_kv_cache
_pool_item
=
kv_cache
_pool_item
def
bind_kvcache
(
self
,
kv_cache
,
pos
=
0
):
self
.
_kv_cache
=
kv_cache
self
.
pos
=
pos
self
.
tokens
=
self
.
tokens
[
pos
:]
def
release_kvcache
(
self
):
cache
=
self
.
_kv_cache
self
.
_kv_cache
=
None
return
cache
def
kvcache
(
self
):
return
self
.
_kv_cache
_pool_item
.
kvcache
return
self
.
_kv_cache
def
outpu
t
(
self
,
out_token
):
self
.
_kv_cache
_pool_item
.
update_tokens
(
self
.
tokens
,
self
.
pos
)
def
nex
t
(
self
,
out_token
):
self
.
_kv_cache
.
update_tokens
(
self
.
tokens
,
self
.
pos
)
self
.
pos
+=
len
(
self
.
tokens
)
if
out_token
==
None
or
out_token
in
self
.
end_tokens
:
...
...
@@ -35,4 +35,25 @@ class InferTask:
else
:
self
.
tokens
=
[
out_token
]
self
.
output_queue
.
sync_q
.
put
(
out_token
)
class
KVCache
:
def
__init__
(
self
,
model
):
self
.
_kvcache
=
model
.
create_kv_cache
()
self
.
tokens
=
[
0
for
_
in
range
(
model
.
max_context_len
())]
def
data
(
self
):
return
self
.
_kvcache
def
drop
(
self
,
model
):
model
.
drop_kv_cache
(
self
.
_kvcache
)
def
update_tokens
(
self
,
tokens
,
pos
):
end
=
pos
+
len
(
tokens
)
max_len
=
len
(
self
.
tokens
)
# If overflow, truncate tokens to fit
if
end
>
max_len
:
tokens
=
tokens
[:
max_len
-
pos
]
end
=
max_len
self
.
tokens
[
pos
:
end
]
=
tokens
scripts/jiuge.py
View file @
2b9ce5a6
from
typing
import
List
from
libinfinicore_infer
import
(
JiugeMeta
,
JiugeWeights
,
KVCache
,
JiugeMeta
CStruct
,
JiugeWeights
CStruct
,
KVCache
CStruct
,
DataType
,
DeviceType
,
create_jiuge_model
,
...
...
@@ -11,7 +11,7 @@ from libinfinicore_infer import (
drop_kv_cache
,
infer_batch
,
)
from
infer_task
import
InferTask
from
infer_task
import
InferTask
,
KVCache
from
ctypes
import
POINTER
,
c_float
,
c_int
,
c_uint
,
c_void_p
,
byref
import
os
...
...
@@ -25,6 +25,7 @@ import transformers
torch
.
set_default_device
(
"cpu"
)
class
LlamaWeightsNaming
:
def
input_embd
(
self
):
return
"model.embed_tokens.weight"
...
...
@@ -78,7 +79,7 @@ class LlamaWeightsNaming:
)
class
JiugeMetaFromLlama
(
JiugeMeta
):
class
JiugeMetaFromLlama
(
JiugeMeta
CStruct
):
def
__init__
(
self
,
config
,
dtype
=
torch
.
float16
):
if
dtype
==
torch
.
float16
:
dt_
=
DataType
.
INFINI_DTYPE_F16
...
...
@@ -107,7 +108,7 @@ class JiugeMetaFromLlama(JiugeMeta):
self
.
torch_dtype_logits
=
dtype
class
JiugeWeightsImpl
(
JiugeWeights
):
class
JiugeWeightsImpl
(
JiugeWeights
CStruct
):
def
__init__
(
self
,
meta
,
...
...
@@ -160,7 +161,9 @@ class JiugeWeightsImpl(JiugeWeights):
self
.
output_norm
=
self
.
output_norm_tensor
.
data_ptr
()
self
.
output_embd_tensor
=
state_dict
[
output_embd_naming
].
to
(
torch_dt_mat
)
if
not
transpose_weight
:
self
.
output_embd_tensor
=
self
.
output_embd_tensor
.
transpose
(
0
,
1
).
contiguous
()
self
.
output_embd_tensor
=
self
.
output_embd_tensor
.
transpose
(
0
,
1
).
contiguous
()
self
.
output_embd
=
self
.
output_embd_tensor
.
data_ptr
()
self
.
attn_norm_tensors
=
[
...
...
@@ -197,7 +200,12 @@ class JiugeWeightsImpl(JiugeWeights):
]
if
not
transpose_weight
:
for
i
in
range
(
nlayer
):
self
.
qkv_tensor
[
i
]
=
self
.
qkv_tensor
[
i
].
reshape
(
ndev
,
(
nh
+
2
*
nkvh
)
//
ndev
*
dh
,
d
).
transpose
(
1
,
2
).
contiguous
()
self
.
qkv_tensor
[
i
]
=
(
self
.
qkv_tensor
[
i
]
.
reshape
(
ndev
,
(
nh
+
2
*
nkvh
)
//
ndev
*
dh
,
d
)
.
transpose
(
1
,
2
)
.
contiguous
()
)
self
.
qkv_tensor_ptrs
=
[
self
.
qkv_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
attn_qkv
=
(
c_void_p
*
nlayer
)(
*
self
.
qkv_tensor_ptrs
)
...
...
@@ -234,13 +242,18 @@ class JiugeWeightsImpl(JiugeWeights):
self
.
attn_qkv_b
=
None
self
.
attn_o_tensor
=
[
(
state_dict
[
naming
.
attn_o
(
i
)]
.
to
(
torch_dt_mat
)
.
reshape
([
d
,
ndev
,
nh
//
ndev
*
dh
])
.
transpose
(
0
,
1
)
.
contiguous
()
if
transpose_weight
else
state_dict
[
naming
.
attn_o
(
i
)].
transpose
(
0
,
1
).
to
(
torch_dt_mat
).
contiguous
()
else
state_dict
[
naming
.
attn_o
(
i
)]
.
transpose
(
0
,
1
)
.
to
(
torch_dt_mat
)
.
contiguous
()
)
for
i
in
range
(
nlayer
)
]
self
.
attn_o_ptrs
=
[
self
.
attn_o_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
...
...
@@ -269,18 +282,28 @@ class JiugeWeightsImpl(JiugeWeights):
]
if
not
transpose_weight
:
for
i
in
range
(
nlayer
):
self
.
gate_up_tensors
[
i
]
=
self
.
gate_up_tensors
[
i
].
reshape
(
ndev
,
2
*
di
//
ndev
,
d
).
transpose
(
1
,
2
).
contiguous
()
self
.
gate_up_tensors
[
i
]
=
(
self
.
gate_up_tensors
[
i
]
.
reshape
(
ndev
,
2
*
di
//
ndev
,
d
)
.
transpose
(
1
,
2
)
.
contiguous
()
)
self
.
gate_up_ptrs
=
[
self
.
gate_up_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
ffn_gate_up
=
(
c_void_p
*
nlayer
)(
*
self
.
gate_up_ptrs
)
self
.
ffn_down_tensor
=
[
(
state_dict
[
naming
.
down
(
i
)]
.
to
(
torch_dt_mat
)
.
reshape
([
d
,
ndev
,
di
//
ndev
])
.
transpose
(
0
,
1
)
.
contiguous
()
if
transpose_weight
else
state_dict
[
naming
.
down
(
i
)].
transpose
(
0
,
1
).
to
(
torch_dt_mat
).
contiguous
()
else
state_dict
[
naming
.
down
(
i
)]
.
transpose
(
0
,
1
)
.
to
(
torch_dt_mat
)
.
contiguous
()
)
for
i
in
range
(
nlayer
)
]
self
.
ffn_down_ptrs
=
[
self
.
ffn_down_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
...
...
@@ -296,7 +319,7 @@ class JiugeBatchedTask:
token_lists
=
[
t
.
tokens
for
t
in
tasks
]
self
.
req_lens_list
=
[
len
(
toks
)
for
toks
in
token_lists
]
self
.
req_pos_list
=
[
t
.
pos
for
t
in
tasks
]
self
.
kv_cache_ptrs
=
[
t
.
kvcache
()
for
t
in
tasks
]
self
.
kv_cache_ptrs
=
[
t
.
kvcache
()
.
data
()
for
t
in
tasks
]
self
.
temperaturas_list
=
[
t
.
temperature
for
t
in
tasks
]
self
.
topks_list
=
[
t
.
topk
for
t
in
tasks
]
self
.
topps_list
=
[
t
.
topp
for
t
in
tasks
]
...
...
@@ -309,7 +332,7 @@ class JiugeBatchedTask:
self
.
tokens
=
(
c_uint
*
self
.
ntok
)(
*
flat_tokens
)
self
.
req_lens
=
(
c_uint
*
self
.
nreq
)(
*
self
.
req_lens_list
)
self
.
req_pos
=
(
c_uint
*
self
.
nreq
)(
*
self
.
req_pos_list
)
self
.
kv_caches
=
(
POINTER
(
KVCache
)
*
self
.
nreq
)(
*
self
.
kv_cache_ptrs
)
self
.
kv_caches
=
(
POINTER
(
KVCache
CStruct
)
*
self
.
nreq
)(
*
self
.
kv_cache_ptrs
)
self
.
temperaturas
=
(
c_float
*
self
.
nreq
)(
*
self
.
temperaturas_list
)
self
.
topks
=
(
c_uint
*
self
.
nreq
)(
*
self
.
topks_list
)
self
.
topps
=
(
c_float
*
self
.
nreq
)(
*
self
.
topps_list
)
...
...
@@ -346,26 +369,46 @@ class JiugeForCauslLM:
config
=
json
.
load
(
f
)
self
.
config
=
config
eos_token_id
=
self
.
config
[
"eos_token_id"
]
self
.
eos_token_id
=
[
eos_token_id
]
if
type
(
eos_token_id
)
==
int
else
eos_token_id
transpose_weight
=
device
!=
DeviceType
.
DEVICE_TYPE_ASCEND
# y = xW is faster than y=xW^T on Ascend
self
.
eos_token_id
=
(
[
eos_token_id
]
if
type
(
eos_token_id
)
==
int
else
eos_token_id
)
transpose_weight
=
(
device
!=
DeviceType
.
DEVICE_TYPE_ASCEND
)
# y = xW is faster than y=xW^T on Ascend
if
"llama"
==
config
[
"model_type"
]:
model
=
transformers
.
LlamaForCausalLM
.
from_pretrained
(
model_dir_path
).
cpu
().
half
()
model
=
(
transformers
.
LlamaForCausalLM
.
from_pretrained
(
model_dir_path
)
.
cpu
()
.
half
()
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
model
.
state_dict
(),
ndev
=
ndev
,
transpose_weight
=
transpose_weight
self
.
meta
,
LlamaWeightsNaming
(),
model
.
state_dict
(),
ndev
=
ndev
,
transpose_weight
=
transpose_weight
,
)
elif
"fm9g"
==
config
[
"model_type"
]:
if
any
(
file
.
suffix
==
".safetensors"
for
file
in
Path
(
model_dir_path
).
iterdir
()):
if
any
(
file
.
suffix
==
".safetensors"
for
file
in
Path
(
model_dir_path
).
iterdir
()
):
state_dict
=
load_all_safetensors_from_dir
(
model_dir_path
)
else
:
state_dict
=
torch
.
load
(
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
weights_only
=
True
,
map_location
=
"cpu"
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
weights_only
=
True
,
map_location
=
"cpu"
,
)
if
LlamaWeightsNaming
.
match
(
state_dict
):
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
,
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
,
trust_remote_code
=
True
...
...
@@ -374,12 +417,18 @@ class JiugeForCauslLM:
raise
ValueError
(
"Unsupported weight naming"
)
elif
"fm9g7b"
==
config
[
"model_type"
]:
state_dict
=
torch
.
load
(
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
weights_only
=
True
,
map_location
=
"cpu"
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
weights_only
=
True
,
map_location
=
"cpu"
,
)
if
LlamaWeightsNaming
.
match
(
state_dict
):
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
,
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
,
trust_remote_code
=
True
...
...
@@ -391,7 +440,11 @@ class JiugeForCauslLM:
if
LlamaWeightsNaming
.
match
(
state_dict
):
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
,
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
...
...
@@ -435,7 +488,6 @@ class JiugeForCauslLM:
return
list
(
output
)
def
generate
(
self
,
input_content
,
max_steps
,
topp_
=
1.0
,
topk_
=
1
,
temperature_
=
1.0
):
kv_cache
=
create_kv_cache
(
self
.
model_instance
)
input_content
=
self
.
tokenizer
.
apply_chat_template
(
conversation
=
[{
"role"
:
"user"
,
"content"
:
input_content
}],
add_generation_prompt
=
True
,
...
...
@@ -443,39 +495,26 @@ class JiugeForCauslLM:
)
print
(
input_content
,
end
=
""
,
flush
=
True
)
tokens
=
self
.
tokenizer
.
encode
(
input_content
)
ntok
=
len
(
tokens
)
nreq
=
1
output_content
=
""
tokens
=
(
c_uint
*
ntok
)(
*
tokens
)
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
req_pos
=
(
c_uint
*
nreq
)(
*
[
0
])
kv_caches
=
(
POINTER
(
KVCache
)
*
nreq
)(
*
[
kv_cache
])
ans
=
(
c_uint
*
nreq
)()
temperature
=
(
c_float
*
nreq
)(
*
[
temperature_
])
topk
=
(
c_uint
*
nreq
)(
*
[
topk_
])
topp
=
(
c_float
*
nreq
)(
*
[
topp_
])
infer_task
=
InferTask
(
0
,
tokens
,
self
.
max_context_len
(),
temperature_
,
topk_
,
topp_
,
self
.
eos_token_id
,
)
infer_task
.
bind_kvcache
(
KVCache
(
self
))
steps
=
0
total_time
=
0
output_content
=
""
for
step_i
in
range
(
max_steps
):
start_time
=
time
.
time
()
infer_batch
(
self
.
model_instance
,
tokens
,
ntok
,
req_lens
,
nreq
,
req_pos
,
kv_caches
,
temperature
,
topk
,
topp
,
ans
,
)
steps
+=
1
output_tokens
=
list
(
ans
)
output_tokens
=
self
.
batch_infer_one_round
([
infer_task
])
end_time
=
time
.
time
()
steps
+=
1
output_str
=
(
self
.
tokenizer
.
_tokenizer
.
id_to_token
(
output_tokens
[
0
])
.
replace
(
"▁"
,
" "
)
...
...
@@ -485,10 +524,7 @@ class JiugeForCauslLM:
print
(
output_str
,
end
=
""
,
flush
=
True
)
if
output_tokens
[
0
]
in
self
.
eos_token_id
:
break
req_pos
[
0
]
=
req_pos
[
0
]
+
ntok
ntok
=
1
tokens
=
(
c_uint
*
ntok
)(
*
output_tokens
)
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
infer_task
.
next
(
output_tokens
[
0
])
if
step_i
>
0
:
total_time
+=
end_time
-
start_time
...
...
@@ -496,8 +532,8 @@ class JiugeForCauslLM:
print
(
"
\n
"
)
avg_time
=
total_time
*
1000
/
(
steps
-
1
)
print
(
f
"Time per step:
{
avg_time
:.
3
f
}
ms"
)
for
kv_cache
in
kv_caches
:
drop_kv_cache
(
self
.
model_instance
,
kv_cache
)
infer_task
.
_kv_cache
.
drop
(
self
)
return
output_content
,
avg_time
def
destroy_model_instance
(
self
):
...
...
scripts/kvcache_pool.py
View file @
2b9ce5a6
from
infer_task
import
KVCache
import
asyncio
from
typing
import
List
class
KVCachePoolItem
:
def
__init__
(
self
,
model
):
self
.
kvcache
=
model
.
create_kv_cache
()
self
.
tokens
=
[
0
for
_
in
range
(
model
.
max_context_len
())]
def
drop
(
self
,
model
):
model
.
drop_kv_cache
(
self
.
kvcache
)
def
update_tokens
(
self
,
tokens
,
pos
):
end
=
pos
+
len
(
tokens
)
max_len
=
len
(
self
.
tokens
)
# If overflow, truncate tokens to fit
if
end
>
max_len
:
tokens
=
tokens
[:
max_len
-
pos
]
end
=
max_len
self
.
tokens
[
pos
:
end
]
=
tokens
import
threading
...
...
@@ -29,7 +9,7 @@ class KVCachePool:
def
__init__
(
self
,
model
,
max_caches
:
int
=
32
):
self
.
max_caches
=
max_caches
self
.
model
=
model
self
.
_available
:
List
[
KVCache
PoolItem
]
=
[]
self
.
_available
:
List
[
KVCache
]
=
[]
self
.
num_caches
=
len
(
self
.
_available
)
self
.
_lock
=
threading
.
Lock
()
self
.
_not_empty
=
threading
.
Condition
(
self
.
_lock
)
...
...
@@ -45,8 +25,10 @@ class KVCachePool:
if
len
(
self
.
_available
)
==
0
:
if
self
.
num_caches
<
self
.
max_caches
:
self
.
num_caches
+=
1
print
(
f
"[INFO] Task
{
infer_task
.
id
}
created new KVCachePoolItem"
)
return
infer_task
.
bind_kvcache
(
KVCachePoolItem
(
self
.
model
),
0
)
print
(
f
"[INFO] Task
{
infer_task
.
id
}
created new KVCachePoolItem"
)
return
infer_task
.
bind_kvcache
(
KVCache
(
self
.
model
),
0
)
else
:
self
.
_not_empty
.
wait
()
else
:
...
...
@@ -62,8 +44,7 @@ class KVCachePool:
def
release_sync
(
self
,
infer_task
):
with
self
.
_not_empty
:
print
(
f
"[INFO] Task
{
infer_task
.
id
}
returned KVCachePoolItem to pool"
)
self
.
_available
.
append
(
infer_task
.
_kv_cache_pool_item
)
infer_task
.
_kv_cache_pool_item
=
None
self
.
_available
.
append
(
infer_task
.
release_kvcache
())
self
.
_not_empty
.
notify
()
async
def
acquire
(
self
,
infer_task
):
...
...
scripts/launch_server.py
View file @
2b9ce5a6
...
...
@@ -70,6 +70,17 @@ print(
f
"Using MAX_BATCH=
{
MAX_BATCH
}
. Try reduce this value if out of memory error occurs."
)
# A wrapper for InferTask that supports async output queue
class
AsyncInferTask
(
InferTask
):
def
__init__
(
self
,
id
,
tokens
,
max_tokens
,
temperature
,
topk
,
topp
,
end_tokens
):
super
().
__init__
(
id
,
tokens
,
max_tokens
,
temperature
,
topk
,
topp
,
end_tokens
)
self
.
output_queue
=
janus
.
Queue
()
print
(
f
"[INFO] Create InferTask
{
self
.
id
}
"
)
def
output
(
self
,
out_token
):
self
.
next
(
out_token
)
self
.
output_queue
.
sync_q
.
put
(
out_token
)
@
contextlib
.
asynccontextmanager
async
def
lifespan
(
app
:
FastAPI
):
...
...
@@ -132,7 +143,7 @@ def build_task(id_, request_data, request: Request):
tokenize
=
False
,
)
tokens
=
request
.
app
.
state
.
model
.
tokenizer
.
encode
(
input_content
)
return
InferTask
(
return
Async
InferTask
(
id_
,
tokens
,
request_data
.
get
(
"max_tokens"
,
request
.
app
.
state
.
model
.
max_context_len
()),
...
...
scripts/libinfinicore_infer.py
View file @
2b9ce5a6
...
...
@@ -35,7 +35,7 @@ class DeviceType(ctypes.c_int):
DEVICE_TYPE_MOORE
=
5
class
JiugeMeta
(
ctypes
.
Structure
):
class
JiugeMeta
CStruct
(
ctypes
.
Structure
):
_fields_
=
[
(
"dt_logits"
,
DataType
),
(
"nlayer"
,
c_size_t
),
...
...
@@ -53,7 +53,7 @@ class JiugeMeta(ctypes.Structure):
# Define the JiugeWeights struct
class
JiugeWeights
(
ctypes
.
Structure
):
class
JiugeWeights
CStruct
(
ctypes
.
Structure
):
_fields_
=
[
(
"nlayer"
,
c_size_t
),
(
"dt_norm"
,
DataType
),
...
...
@@ -72,11 +72,11 @@ class JiugeWeights(ctypes.Structure):
]
class
JiugeModel
(
ctypes
.
Structure
):
class
JiugeModel
CSruct
(
ctypes
.
Structure
):
pass
class
KVCache
(
ctypes
.
Structure
):
class
KVCache
CStruct
(
ctypes
.
Structure
):
pass
...
...
@@ -85,27 +85,27 @@ def __open_library__():
os
.
environ
.
get
(
"INFINI_ROOT"
),
"lib"
,
"libinfinicore_infer.so"
)
lib
=
ctypes
.
CDLL
(
lib_path
)
lib
.
createJiugeModel
.
restype
=
POINTER
(
JiugeModel
)
lib
.
createJiugeModel
.
restype
=
POINTER
(
JiugeModel
CSruct
)
lib
.
createJiugeModel
.
argtypes
=
[
POINTER
(
JiugeMeta
),
# JiugeMeta const *
POINTER
(
JiugeWeights
),
# JiugeWeights const *
POINTER
(
JiugeMeta
CStruct
),
# JiugeMeta const *
POINTER
(
JiugeWeights
CStruct
),
# JiugeWeights const *
DeviceType
,
# DeviceType
c_int
,
# int ndev
POINTER
(
c_int
),
# int const *dev_ids
]
lib
.
destroyJiugeModel
.
argtypes
=
[
POINTER
(
JiugeModel
)]
lib
.
createKVCache
.
argtypes
=
[
POINTER
(
JiugeModel
)]
lib
.
createKVCache
.
restype
=
POINTER
(
KVCache
)
lib
.
dropKVCache
.
argtypes
=
[
POINTER
(
JiugeModel
),
POINTER
(
KVCache
)]
lib
.
destroyJiugeModel
.
argtypes
=
[
POINTER
(
JiugeModel
CSruct
)]
lib
.
createKVCache
.
argtypes
=
[
POINTER
(
JiugeModel
CSruct
)]
lib
.
createKVCache
.
restype
=
POINTER
(
KVCache
CStruct
)
lib
.
dropKVCache
.
argtypes
=
[
POINTER
(
JiugeModel
CSruct
),
POINTER
(
KVCache
CStruct
)]
lib
.
inferBatch
.
restype
=
None
lib
.
inferBatch
.
argtypes
=
[
POINTER
(
JiugeModel
),
# struct JiugeModel const *
POINTER
(
JiugeModel
CSruct
),
# struct JiugeModel const *
POINTER
(
c_uint
),
# unsigned int const *tokens
c_uint
,
# unsigned int ntok
POINTER
(
c_uint
),
# unsigned int const *req_lens
c_uint
,
# unsigned int nreq
POINTER
(
c_uint
),
# unsigned int const *req_pos
POINTER
(
POINTER
(
KVCache
)),
# struct KVCache **kv_caches
POINTER
(
POINTER
(
KVCache
CStruct
)),
# struct KVCache **kv_caches
POINTER
(
c_float
),
# float temperature
POINTER
(
c_uint
),
# unsigned int topk
POINTER
(
c_float
),
# float topp
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment