Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
2b9ce5a6
Commit
2b9ce5a6
authored
Jun 25, 2025
by
Pan Zezhong
Browse files
refactor task to avoid model dependency on async libs
parent
cfc8b598
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
171 additions
and
122 deletions
+171
-122
scripts/infer_task.py
scripts/infer_task.py
+33
-12
scripts/jiuge.py
scripts/jiuge.py
+105
-69
scripts/kvcache_pool.py
scripts/kvcache_pool.py
+8
-27
scripts/launch_server.py
scripts/launch_server.py
+12
-1
scripts/libinfinicore_infer.py
scripts/libinfinicore_infer.py
+13
-13
No files found.
scripts/infer_task.py
View file @
2b9ce5a6
import
janus
class
InferTask
:
class
InferTask
:
def
__init__
(
self
,
id
,
tokens
,
max_tokens
,
temperature
,
topk
,
topp
,
end_tokens
):
def
__init__
(
self
,
id
,
tokens
,
max_tokens
,
temperature
,
topk
,
topp
,
end_tokens
):
self
.
id
=
id
self
.
id
=
id
...
@@ -11,21 +8,24 @@ class InferTask:
...
@@ -11,21 +8,24 @@ class InferTask:
self
.
topk
=
topk
self
.
topk
=
topk
self
.
topp
=
topp
self
.
topp
=
topp
self
.
end_tokens
=
end_tokens
self
.
end_tokens
=
end_tokens
self
.
output_queue
=
janus
.
Queue
()
self
.
_kv_cache
=
None
self
.
_kv_cache_pool_item
=
None
self
.
pos
=
0
self
.
pos
=
0
print
(
f
"[INFO] Create InferTask
{
self
.
id
}
"
)
def
bind_kvcache
(
self
,
kv_cache
_pool_item
,
pos
):
def
bind_kvcache
(
self
,
kv_cache
,
pos
=
0
):
self
.
_kv_cache
_pool_item
=
kv_cache
_pool_item
self
.
_kv_cache
=
kv_cache
self
.
pos
=
pos
self
.
pos
=
pos
self
.
tokens
=
self
.
tokens
[
pos
:]
self
.
tokens
=
self
.
tokens
[
pos
:]
def
release_kvcache
(
self
):
cache
=
self
.
_kv_cache
self
.
_kv_cache
=
None
return
cache
def
kvcache
(
self
):
def
kvcache
(
self
):
return
self
.
_kv_cache
_pool_item
.
kvcache
return
self
.
_kv_cache
def
outpu
t
(
self
,
out_token
):
def
nex
t
(
self
,
out_token
):
self
.
_kv_cache
_pool_item
.
update_tokens
(
self
.
tokens
,
self
.
pos
)
self
.
_kv_cache
.
update_tokens
(
self
.
tokens
,
self
.
pos
)
self
.
pos
+=
len
(
self
.
tokens
)
self
.
pos
+=
len
(
self
.
tokens
)
if
out_token
==
None
or
out_token
in
self
.
end_tokens
:
if
out_token
==
None
or
out_token
in
self
.
end_tokens
:
...
@@ -35,4 +35,25 @@ class InferTask:
...
@@ -35,4 +35,25 @@ class InferTask:
else
:
else
:
self
.
tokens
=
[
out_token
]
self
.
tokens
=
[
out_token
]
self
.
output_queue
.
sync_q
.
put
(
out_token
)
class
KVCache
:
def
__init__
(
self
,
model
):
self
.
_kvcache
=
model
.
create_kv_cache
()
self
.
tokens
=
[
0
for
_
in
range
(
model
.
max_context_len
())]
def
data
(
self
):
return
self
.
_kvcache
def
drop
(
self
,
model
):
model
.
drop_kv_cache
(
self
.
_kvcache
)
def
update_tokens
(
self
,
tokens
,
pos
):
end
=
pos
+
len
(
tokens
)
max_len
=
len
(
self
.
tokens
)
# If overflow, truncate tokens to fit
if
end
>
max_len
:
tokens
=
tokens
[:
max_len
-
pos
]
end
=
max_len
self
.
tokens
[
pos
:
end
]
=
tokens
scripts/jiuge.py
View file @
2b9ce5a6
from
typing
import
List
from
typing
import
List
from
libinfinicore_infer
import
(
from
libinfinicore_infer
import
(
JiugeMeta
,
JiugeMeta
CStruct
,
JiugeWeights
,
JiugeWeights
CStruct
,
KVCache
,
KVCache
CStruct
,
DataType
,
DataType
,
DeviceType
,
DeviceType
,
create_jiuge_model
,
create_jiuge_model
,
...
@@ -11,7 +11,7 @@ from libinfinicore_infer import (
...
@@ -11,7 +11,7 @@ from libinfinicore_infer import (
drop_kv_cache
,
drop_kv_cache
,
infer_batch
,
infer_batch
,
)
)
from
infer_task
import
InferTask
from
infer_task
import
InferTask
,
KVCache
from
ctypes
import
POINTER
,
c_float
,
c_int
,
c_uint
,
c_void_p
,
byref
from
ctypes
import
POINTER
,
c_float
,
c_int
,
c_uint
,
c_void_p
,
byref
import
os
import
os
...
@@ -25,6 +25,7 @@ import transformers
...
@@ -25,6 +25,7 @@ import transformers
torch
.
set_default_device
(
"cpu"
)
torch
.
set_default_device
(
"cpu"
)
class
LlamaWeightsNaming
:
class
LlamaWeightsNaming
:
def
input_embd
(
self
):
def
input_embd
(
self
):
return
"model.embed_tokens.weight"
return
"model.embed_tokens.weight"
...
@@ -78,7 +79,7 @@ class LlamaWeightsNaming:
...
@@ -78,7 +79,7 @@ class LlamaWeightsNaming:
)
)
class
JiugeMetaFromLlama
(
JiugeMeta
):
class
JiugeMetaFromLlama
(
JiugeMeta
CStruct
):
def
__init__
(
self
,
config
,
dtype
=
torch
.
float16
):
def
__init__
(
self
,
config
,
dtype
=
torch
.
float16
):
if
dtype
==
torch
.
float16
:
if
dtype
==
torch
.
float16
:
dt_
=
DataType
.
INFINI_DTYPE_F16
dt_
=
DataType
.
INFINI_DTYPE_F16
...
@@ -107,7 +108,7 @@ class JiugeMetaFromLlama(JiugeMeta):
...
@@ -107,7 +108,7 @@ class JiugeMetaFromLlama(JiugeMeta):
self
.
torch_dtype_logits
=
dtype
self
.
torch_dtype_logits
=
dtype
class
JiugeWeightsImpl
(
JiugeWeights
):
class
JiugeWeightsImpl
(
JiugeWeights
CStruct
):
def
__init__
(
def
__init__
(
self
,
self
,
meta
,
meta
,
...
@@ -160,7 +161,9 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -160,7 +161,9 @@ class JiugeWeightsImpl(JiugeWeights):
self
.
output_norm
=
self
.
output_norm_tensor
.
data_ptr
()
self
.
output_norm
=
self
.
output_norm_tensor
.
data_ptr
()
self
.
output_embd_tensor
=
state_dict
[
output_embd_naming
].
to
(
torch_dt_mat
)
self
.
output_embd_tensor
=
state_dict
[
output_embd_naming
].
to
(
torch_dt_mat
)
if
not
transpose_weight
:
if
not
transpose_weight
:
self
.
output_embd_tensor
=
self
.
output_embd_tensor
.
transpose
(
0
,
1
).
contiguous
()
self
.
output_embd_tensor
=
self
.
output_embd_tensor
.
transpose
(
0
,
1
).
contiguous
()
self
.
output_embd
=
self
.
output_embd_tensor
.
data_ptr
()
self
.
output_embd
=
self
.
output_embd_tensor
.
data_ptr
()
self
.
attn_norm_tensors
=
[
self
.
attn_norm_tensors
=
[
...
@@ -197,7 +200,12 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -197,7 +200,12 @@ class JiugeWeightsImpl(JiugeWeights):
]
]
if
not
transpose_weight
:
if
not
transpose_weight
:
for
i
in
range
(
nlayer
):
for
i
in
range
(
nlayer
):
self
.
qkv_tensor
[
i
]
=
self
.
qkv_tensor
[
i
].
reshape
(
ndev
,
(
nh
+
2
*
nkvh
)
//
ndev
*
dh
,
d
).
transpose
(
1
,
2
).
contiguous
()
self
.
qkv_tensor
[
i
]
=
(
self
.
qkv_tensor
[
i
]
.
reshape
(
ndev
,
(
nh
+
2
*
nkvh
)
//
ndev
*
dh
,
d
)
.
transpose
(
1
,
2
)
.
contiguous
()
)
self
.
qkv_tensor_ptrs
=
[
self
.
qkv_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
qkv_tensor_ptrs
=
[
self
.
qkv_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
attn_qkv
=
(
c_void_p
*
nlayer
)(
*
self
.
qkv_tensor_ptrs
)
self
.
attn_qkv
=
(
c_void_p
*
nlayer
)(
*
self
.
qkv_tensor_ptrs
)
...
@@ -234,13 +242,18 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -234,13 +242,18 @@ class JiugeWeightsImpl(JiugeWeights):
self
.
attn_qkv_b
=
None
self
.
attn_qkv_b
=
None
self
.
attn_o_tensor
=
[
self
.
attn_o_tensor
=
[
state_dict
[
naming
.
attn_o
(
i
)]
(
state_dict
[
naming
.
attn_o
(
i
)]
.
to
(
torch_dt_mat
)
.
to
(
torch_dt_mat
)
.
reshape
([
d
,
ndev
,
nh
//
ndev
*
dh
])
.
reshape
([
d
,
ndev
,
nh
//
ndev
*
dh
])
.
transpose
(
0
,
1
)
.
transpose
(
0
,
1
)
.
contiguous
()
.
contiguous
()
if
transpose_weight
if
transpose_weight
else
state_dict
[
naming
.
attn_o
(
i
)].
transpose
(
0
,
1
).
to
(
torch_dt_mat
).
contiguous
()
else
state_dict
[
naming
.
attn_o
(
i
)]
.
transpose
(
0
,
1
)
.
to
(
torch_dt_mat
)
.
contiguous
()
)
for
i
in
range
(
nlayer
)
for
i
in
range
(
nlayer
)
]
]
self
.
attn_o_ptrs
=
[
self
.
attn_o_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
attn_o_ptrs
=
[
self
.
attn_o_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
...
@@ -269,18 +282,28 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -269,18 +282,28 @@ class JiugeWeightsImpl(JiugeWeights):
]
]
if
not
transpose_weight
:
if
not
transpose_weight
:
for
i
in
range
(
nlayer
):
for
i
in
range
(
nlayer
):
self
.
gate_up_tensors
[
i
]
=
self
.
gate_up_tensors
[
i
].
reshape
(
ndev
,
2
*
di
//
ndev
,
d
).
transpose
(
1
,
2
).
contiguous
()
self
.
gate_up_tensors
[
i
]
=
(
self
.
gate_up_tensors
[
i
]
.
reshape
(
ndev
,
2
*
di
//
ndev
,
d
)
.
transpose
(
1
,
2
)
.
contiguous
()
)
self
.
gate_up_ptrs
=
[
self
.
gate_up_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
gate_up_ptrs
=
[
self
.
gate_up_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
ffn_gate_up
=
(
c_void_p
*
nlayer
)(
*
self
.
gate_up_ptrs
)
self
.
ffn_gate_up
=
(
c_void_p
*
nlayer
)(
*
self
.
gate_up_ptrs
)
self
.
ffn_down_tensor
=
[
self
.
ffn_down_tensor
=
[
state_dict
[
naming
.
down
(
i
)]
(
.
to
(
torch_dt_mat
)
state_dict
[
naming
.
down
(
i
)]
.
reshape
([
d
,
ndev
,
di
//
ndev
])
.
to
(
torch_dt_mat
)
.
transpose
(
0
,
1
)
.
reshape
([
d
,
ndev
,
di
//
ndev
])
.
contiguous
()
.
transpose
(
0
,
1
)
if
transpose_weight
.
contiguous
()
else
state_dict
[
naming
.
down
(
i
)].
transpose
(
0
,
1
).
to
(
torch_dt_mat
).
contiguous
()
if
transpose_weight
else
state_dict
[
naming
.
down
(
i
)]
.
transpose
(
0
,
1
)
.
to
(
torch_dt_mat
)
.
contiguous
()
)
for
i
in
range
(
nlayer
)
for
i
in
range
(
nlayer
)
]
]
self
.
ffn_down_ptrs
=
[
self
.
ffn_down_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
ffn_down_ptrs
=
[
self
.
ffn_down_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
...
@@ -296,7 +319,7 @@ class JiugeBatchedTask:
...
@@ -296,7 +319,7 @@ class JiugeBatchedTask:
token_lists
=
[
t
.
tokens
for
t
in
tasks
]
token_lists
=
[
t
.
tokens
for
t
in
tasks
]
self
.
req_lens_list
=
[
len
(
toks
)
for
toks
in
token_lists
]
self
.
req_lens_list
=
[
len
(
toks
)
for
toks
in
token_lists
]
self
.
req_pos_list
=
[
t
.
pos
for
t
in
tasks
]
self
.
req_pos_list
=
[
t
.
pos
for
t
in
tasks
]
self
.
kv_cache_ptrs
=
[
t
.
kvcache
()
for
t
in
tasks
]
self
.
kv_cache_ptrs
=
[
t
.
kvcache
()
.
data
()
for
t
in
tasks
]
self
.
temperaturas_list
=
[
t
.
temperature
for
t
in
tasks
]
self
.
temperaturas_list
=
[
t
.
temperature
for
t
in
tasks
]
self
.
topks_list
=
[
t
.
topk
for
t
in
tasks
]
self
.
topks_list
=
[
t
.
topk
for
t
in
tasks
]
self
.
topps_list
=
[
t
.
topp
for
t
in
tasks
]
self
.
topps_list
=
[
t
.
topp
for
t
in
tasks
]
...
@@ -309,7 +332,7 @@ class JiugeBatchedTask:
...
@@ -309,7 +332,7 @@ class JiugeBatchedTask:
self
.
tokens
=
(
c_uint
*
self
.
ntok
)(
*
flat_tokens
)
self
.
tokens
=
(
c_uint
*
self
.
ntok
)(
*
flat_tokens
)
self
.
req_lens
=
(
c_uint
*
self
.
nreq
)(
*
self
.
req_lens_list
)
self
.
req_lens
=
(
c_uint
*
self
.
nreq
)(
*
self
.
req_lens_list
)
self
.
req_pos
=
(
c_uint
*
self
.
nreq
)(
*
self
.
req_pos_list
)
self
.
req_pos
=
(
c_uint
*
self
.
nreq
)(
*
self
.
req_pos_list
)
self
.
kv_caches
=
(
POINTER
(
KVCache
)
*
self
.
nreq
)(
*
self
.
kv_cache_ptrs
)
self
.
kv_caches
=
(
POINTER
(
KVCache
CStruct
)
*
self
.
nreq
)(
*
self
.
kv_cache_ptrs
)
self
.
temperaturas
=
(
c_float
*
self
.
nreq
)(
*
self
.
temperaturas_list
)
self
.
temperaturas
=
(
c_float
*
self
.
nreq
)(
*
self
.
temperaturas_list
)
self
.
topks
=
(
c_uint
*
self
.
nreq
)(
*
self
.
topks_list
)
self
.
topks
=
(
c_uint
*
self
.
nreq
)(
*
self
.
topks_list
)
self
.
topps
=
(
c_float
*
self
.
nreq
)(
*
self
.
topps_list
)
self
.
topps
=
(
c_float
*
self
.
nreq
)(
*
self
.
topps_list
)
...
@@ -338,7 +361,7 @@ class JiugeForCauslLM:
...
@@ -338,7 +361,7 @@ class JiugeForCauslLM:
for
name_
in
data_
.
keys
():
for
name_
in
data_
.
keys
():
tensors_
[
name_
]
=
data_
.
get_tensor
(
name_
)
tensors_
[
name_
]
=
data_
.
get_tensor
(
name_
)
return
tensors_
return
tensors_
print
(
"Loading model weights to host..."
)
print
(
"Loading model weights to host..."
)
load_start_time
=
time
.
time
()
load_start_time
=
time
.
time
()
...
@@ -346,26 +369,46 @@ class JiugeForCauslLM:
...
@@ -346,26 +369,46 @@ class JiugeForCauslLM:
config
=
json
.
load
(
f
)
config
=
json
.
load
(
f
)
self
.
config
=
config
self
.
config
=
config
eos_token_id
=
self
.
config
[
"eos_token_id"
]
eos_token_id
=
self
.
config
[
"eos_token_id"
]
self
.
eos_token_id
=
[
eos_token_id
]
if
type
(
eos_token_id
)
==
int
else
eos_token_id
self
.
eos_token_id
=
(
transpose_weight
=
device
!=
DeviceType
.
DEVICE_TYPE_ASCEND
# y = xW is faster than y=xW^T on Ascend
[
eos_token_id
]
if
type
(
eos_token_id
)
==
int
else
eos_token_id
)
transpose_weight
=
(
device
!=
DeviceType
.
DEVICE_TYPE_ASCEND
)
# y = xW is faster than y=xW^T on Ascend
if
"llama"
==
config
[
"model_type"
]:
if
"llama"
==
config
[
"model_type"
]:
model
=
transformers
.
LlamaForCausalLM
.
from_pretrained
(
model_dir_path
).
cpu
().
half
()
model
=
(
transformers
.
LlamaForCausalLM
.
from_pretrained
(
model_dir_path
)
.
cpu
()
.
half
()
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
model
.
state_dict
(),
ndev
=
ndev
,
transpose_weight
=
transpose_weight
self
.
meta
,
LlamaWeightsNaming
(),
model
.
state_dict
(),
ndev
=
ndev
,
transpose_weight
=
transpose_weight
,
)
)
elif
"fm9g"
==
config
[
"model_type"
]:
elif
"fm9g"
==
config
[
"model_type"
]:
if
any
(
file
.
suffix
==
".safetensors"
for
file
in
Path
(
model_dir_path
).
iterdir
()):
if
any
(
file
.
suffix
==
".safetensors"
for
file
in
Path
(
model_dir_path
).
iterdir
()
):
state_dict
=
load_all_safetensors_from_dir
(
model_dir_path
)
state_dict
=
load_all_safetensors_from_dir
(
model_dir_path
)
else
:
else
:
state_dict
=
torch
.
load
(
state_dict
=
torch
.
load
(
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
weights_only
=
True
,
map_location
=
"cpu"
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
)
weights_only
=
True
,
map_location
=
"cpu"
,
)
if
LlamaWeightsNaming
.
match
(
state_dict
):
if
LlamaWeightsNaming
.
match
(
state_dict
):
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
,
)
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
,
trust_remote_code
=
True
model_dir_path
,
trust_remote_code
=
True
...
@@ -374,12 +417,18 @@ class JiugeForCauslLM:
...
@@ -374,12 +417,18 @@ class JiugeForCauslLM:
raise
ValueError
(
"Unsupported weight naming"
)
raise
ValueError
(
"Unsupported weight naming"
)
elif
"fm9g7b"
==
config
[
"model_type"
]:
elif
"fm9g7b"
==
config
[
"model_type"
]:
state_dict
=
torch
.
load
(
state_dict
=
torch
.
load
(
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
weights_only
=
True
,
map_location
=
"cpu"
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
weights_only
=
True
,
map_location
=
"cpu"
,
)
)
if
LlamaWeightsNaming
.
match
(
state_dict
):
if
LlamaWeightsNaming
.
match
(
state_dict
):
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
,
)
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
,
trust_remote_code
=
True
model_dir_path
,
trust_remote_code
=
True
...
@@ -391,7 +440,11 @@ class JiugeForCauslLM:
...
@@ -391,7 +440,11 @@ class JiugeForCauslLM:
if
LlamaWeightsNaming
.
match
(
state_dict
):
if
LlamaWeightsNaming
.
match
(
state_dict
):
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
,
)
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
model_dir_path
...
@@ -401,7 +454,7 @@ class JiugeForCauslLM:
...
@@ -401,7 +454,7 @@ class JiugeForCauslLM:
load_end_time
=
time
.
time
()
load_end_time
=
time
.
time
()
print
(
f
"Time used:
{
load_end_time
-
load_start_time
:.
3
f
}
s"
)
print
(
f
"Time used:
{
load_end_time
-
load_start_time
:.
3
f
}
s"
)
print
(
f
"Creating model on
{
ndev
}
devices..."
)
print
(
f
"Creating model on
{
ndev
}
devices..."
)
load_start_time
=
time
.
time
()
load_start_time
=
time
.
time
()
dev_ids
=
(
c_int
*
ndev
)(
*
[
i
for
i
in
range
(
ndev
)])
dev_ids
=
(
c_int
*
ndev
)(
*
[
i
for
i
in
range
(
ndev
)])
...
@@ -414,10 +467,10 @@ class JiugeForCauslLM:
...
@@ -414,10 +467,10 @@ class JiugeForCauslLM:
)
)
load_end_time
=
time
.
time
()
load_end_time
=
time
.
time
()
print
(
f
"Time used:
{
load_end_time
-
load_start_time
:.
3
f
}
s"
)
print
(
f
"Time used:
{
load_end_time
-
load_start_time
:.
3
f
}
s"
)
def
max_context_len
(
self
):
def
max_context_len
(
self
):
return
self
.
meta
.
dctx
return
self
.
meta
.
dctx
def
create_kv_cache
(
self
):
def
create_kv_cache
(
self
):
return
create_kv_cache
(
self
.
model_instance
)
return
create_kv_cache
(
self
.
model_instance
)
...
@@ -435,7 +488,6 @@ class JiugeForCauslLM:
...
@@ -435,7 +488,6 @@ class JiugeForCauslLM:
return
list
(
output
)
return
list
(
output
)
def
generate
(
self
,
input_content
,
max_steps
,
topp_
=
1.0
,
topk_
=
1
,
temperature_
=
1.0
):
def
generate
(
self
,
input_content
,
max_steps
,
topp_
=
1.0
,
topk_
=
1
,
temperature_
=
1.0
):
kv_cache
=
create_kv_cache
(
self
.
model_instance
)
input_content
=
self
.
tokenizer
.
apply_chat_template
(
input_content
=
self
.
tokenizer
.
apply_chat_template
(
conversation
=
[{
"role"
:
"user"
,
"content"
:
input_content
}],
conversation
=
[{
"role"
:
"user"
,
"content"
:
input_content
}],
add_generation_prompt
=
True
,
add_generation_prompt
=
True
,
...
@@ -443,39 +495,26 @@ class JiugeForCauslLM:
...
@@ -443,39 +495,26 @@ class JiugeForCauslLM:
)
)
print
(
input_content
,
end
=
""
,
flush
=
True
)
print
(
input_content
,
end
=
""
,
flush
=
True
)
tokens
=
self
.
tokenizer
.
encode
(
input_content
)
tokens
=
self
.
tokenizer
.
encode
(
input_content
)
ntok
=
len
(
tokens
)
infer_task
=
InferTask
(
nreq
=
1
0
,
output_content
=
""
tokens
,
tokens
=
(
c_uint
*
ntok
)(
*
tokens
)
self
.
max_context_len
(),
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
temperature_
,
req_pos
=
(
c_uint
*
nreq
)(
*
[
0
])
topk_
,
kv_caches
=
(
POINTER
(
KVCache
)
*
nreq
)(
*
[
kv_cache
])
topp_
,
ans
=
(
c_uint
*
nreq
)()
self
.
eos_token_id
,
temperature
=
(
c_float
*
nreq
)(
*
[
temperature_
])
)
topk
=
(
c_uint
*
nreq
)(
*
[
topk_
])
infer_task
.
bind_kvcache
(
KVCache
(
self
))
topp
=
(
c_float
*
nreq
)(
*
[
topp_
])
steps
=
0
steps
=
0
total_time
=
0
total_time
=
0
output_content
=
""
for
step_i
in
range
(
max_steps
):
for
step_i
in
range
(
max_steps
):
start_time
=
time
.
time
()
start_time
=
time
.
time
()
infer_batch
(
output_tokens
=
self
.
batch_infer_one_round
([
infer_task
])
self
.
model_instance
,
tokens
,
ntok
,
req_lens
,
nreq
,
req_pos
,
kv_caches
,
temperature
,
topk
,
topp
,
ans
,
)
steps
+=
1
output_tokens
=
list
(
ans
)
end_time
=
time
.
time
()
end_time
=
time
.
time
()
steps
+=
1
output_str
=
(
output_str
=
(
self
.
tokenizer
.
_tokenizer
.
id_to_token
(
output_tokens
[
0
])
self
.
tokenizer
.
_tokenizer
.
id_to_token
(
output_tokens
[
0
])
.
replace
(
"▁"
,
" "
)
.
replace
(
"▁"
,
" "
)
...
@@ -485,10 +524,7 @@ class JiugeForCauslLM:
...
@@ -485,10 +524,7 @@ class JiugeForCauslLM:
print
(
output_str
,
end
=
""
,
flush
=
True
)
print
(
output_str
,
end
=
""
,
flush
=
True
)
if
output_tokens
[
0
]
in
self
.
eos_token_id
:
if
output_tokens
[
0
]
in
self
.
eos_token_id
:
break
break
req_pos
[
0
]
=
req_pos
[
0
]
+
ntok
infer_task
.
next
(
output_tokens
[
0
])
ntok
=
1
tokens
=
(
c_uint
*
ntok
)(
*
output_tokens
)
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
if
step_i
>
0
:
if
step_i
>
0
:
total_time
+=
end_time
-
start_time
total_time
+=
end_time
-
start_time
...
@@ -496,8 +532,8 @@ class JiugeForCauslLM:
...
@@ -496,8 +532,8 @@ class JiugeForCauslLM:
print
(
"
\n
"
)
print
(
"
\n
"
)
avg_time
=
total_time
*
1000
/
(
steps
-
1
)
avg_time
=
total_time
*
1000
/
(
steps
-
1
)
print
(
f
"Time per step:
{
avg_time
:.
3
f
}
ms"
)
print
(
f
"Time per step:
{
avg_time
:.
3
f
}
ms"
)
for
kv_cache
in
kv_caches
:
drop_kv_cache
(
self
.
model_instance
,
kv_cache
)
infer_task
.
_kv_cache
.
drop
(
self
)
return
output_content
,
avg_time
return
output_content
,
avg_time
def
destroy_model_instance
(
self
):
def
destroy_model_instance
(
self
):
...
...
scripts/kvcache_pool.py
View file @
2b9ce5a6
from
infer_task
import
KVCache
import
asyncio
import
asyncio
from
typing
import
List
from
typing
import
List
class
KVCachePoolItem
:
def
__init__
(
self
,
model
):
self
.
kvcache
=
model
.
create_kv_cache
()
self
.
tokens
=
[
0
for
_
in
range
(
model
.
max_context_len
())]
def
drop
(
self
,
model
):
model
.
drop_kv_cache
(
self
.
kvcache
)
def
update_tokens
(
self
,
tokens
,
pos
):
end
=
pos
+
len
(
tokens
)
max_len
=
len
(
self
.
tokens
)
# If overflow, truncate tokens to fit
if
end
>
max_len
:
tokens
=
tokens
[:
max_len
-
pos
]
end
=
max_len
self
.
tokens
[
pos
:
end
]
=
tokens
import
threading
import
threading
...
@@ -29,7 +9,7 @@ class KVCachePool:
...
@@ -29,7 +9,7 @@ class KVCachePool:
def
__init__
(
self
,
model
,
max_caches
:
int
=
32
):
def
__init__
(
self
,
model
,
max_caches
:
int
=
32
):
self
.
max_caches
=
max_caches
self
.
max_caches
=
max_caches
self
.
model
=
model
self
.
model
=
model
self
.
_available
:
List
[
KVCache
PoolItem
]
=
[]
self
.
_available
:
List
[
KVCache
]
=
[]
self
.
num_caches
=
len
(
self
.
_available
)
self
.
num_caches
=
len
(
self
.
_available
)
self
.
_lock
=
threading
.
Lock
()
self
.
_lock
=
threading
.
Lock
()
self
.
_not_empty
=
threading
.
Condition
(
self
.
_lock
)
self
.
_not_empty
=
threading
.
Condition
(
self
.
_lock
)
...
@@ -45,8 +25,10 @@ class KVCachePool:
...
@@ -45,8 +25,10 @@ class KVCachePool:
if
len
(
self
.
_available
)
==
0
:
if
len
(
self
.
_available
)
==
0
:
if
self
.
num_caches
<
self
.
max_caches
:
if
self
.
num_caches
<
self
.
max_caches
:
self
.
num_caches
+=
1
self
.
num_caches
+=
1
print
(
f
"[INFO] Task
{
infer_task
.
id
}
created new KVCachePoolItem"
)
print
(
return
infer_task
.
bind_kvcache
(
KVCachePoolItem
(
self
.
model
),
0
)
f
"[INFO] Task
{
infer_task
.
id
}
created new KVCachePoolItem"
)
return
infer_task
.
bind_kvcache
(
KVCache
(
self
.
model
),
0
)
else
:
else
:
self
.
_not_empty
.
wait
()
self
.
_not_empty
.
wait
()
else
:
else
:
...
@@ -62,8 +44,7 @@ class KVCachePool:
...
@@ -62,8 +44,7 @@ class KVCachePool:
def
release_sync
(
self
,
infer_task
):
def
release_sync
(
self
,
infer_task
):
with
self
.
_not_empty
:
with
self
.
_not_empty
:
print
(
f
"[INFO] Task
{
infer_task
.
id
}
returned KVCachePoolItem to pool"
)
print
(
f
"[INFO] Task
{
infer_task
.
id
}
returned KVCachePoolItem to pool"
)
self
.
_available
.
append
(
infer_task
.
_kv_cache_pool_item
)
self
.
_available
.
append
(
infer_task
.
release_kvcache
())
infer_task
.
_kv_cache_pool_item
=
None
self
.
_not_empty
.
notify
()
self
.
_not_empty
.
notify
()
async
def
acquire
(
self
,
infer_task
):
async
def
acquire
(
self
,
infer_task
):
...
...
scripts/launch_server.py
View file @
2b9ce5a6
...
@@ -70,6 +70,17 @@ print(
...
@@ -70,6 +70,17 @@ print(
f
"Using MAX_BATCH=
{
MAX_BATCH
}
. Try reduce this value if out of memory error occurs."
f
"Using MAX_BATCH=
{
MAX_BATCH
}
. Try reduce this value if out of memory error occurs."
)
)
# A wrapper for InferTask that supports async output queue
class
AsyncInferTask
(
InferTask
):
def
__init__
(
self
,
id
,
tokens
,
max_tokens
,
temperature
,
topk
,
topp
,
end_tokens
):
super
().
__init__
(
id
,
tokens
,
max_tokens
,
temperature
,
topk
,
topp
,
end_tokens
)
self
.
output_queue
=
janus
.
Queue
()
print
(
f
"[INFO] Create InferTask
{
self
.
id
}
"
)
def
output
(
self
,
out_token
):
self
.
next
(
out_token
)
self
.
output_queue
.
sync_q
.
put
(
out_token
)
@
contextlib
.
asynccontextmanager
@
contextlib
.
asynccontextmanager
async
def
lifespan
(
app
:
FastAPI
):
async
def
lifespan
(
app
:
FastAPI
):
...
@@ -132,7 +143,7 @@ def build_task(id_, request_data, request: Request):
...
@@ -132,7 +143,7 @@ def build_task(id_, request_data, request: Request):
tokenize
=
False
,
tokenize
=
False
,
)
)
tokens
=
request
.
app
.
state
.
model
.
tokenizer
.
encode
(
input_content
)
tokens
=
request
.
app
.
state
.
model
.
tokenizer
.
encode
(
input_content
)
return
InferTask
(
return
Async
InferTask
(
id_
,
id_
,
tokens
,
tokens
,
request_data
.
get
(
"max_tokens"
,
request
.
app
.
state
.
model
.
max_context_len
()),
request_data
.
get
(
"max_tokens"
,
request
.
app
.
state
.
model
.
max_context_len
()),
...
...
scripts/libinfinicore_infer.py
View file @
2b9ce5a6
...
@@ -35,7 +35,7 @@ class DeviceType(ctypes.c_int):
...
@@ -35,7 +35,7 @@ class DeviceType(ctypes.c_int):
DEVICE_TYPE_MOORE
=
5
DEVICE_TYPE_MOORE
=
5
class
JiugeMeta
(
ctypes
.
Structure
):
class
JiugeMeta
CStruct
(
ctypes
.
Structure
):
_fields_
=
[
_fields_
=
[
(
"dt_logits"
,
DataType
),
(
"dt_logits"
,
DataType
),
(
"nlayer"
,
c_size_t
),
(
"nlayer"
,
c_size_t
),
...
@@ -53,7 +53,7 @@ class JiugeMeta(ctypes.Structure):
...
@@ -53,7 +53,7 @@ class JiugeMeta(ctypes.Structure):
# Define the JiugeWeights struct
# Define the JiugeWeights struct
class
JiugeWeights
(
ctypes
.
Structure
):
class
JiugeWeights
CStruct
(
ctypes
.
Structure
):
_fields_
=
[
_fields_
=
[
(
"nlayer"
,
c_size_t
),
(
"nlayer"
,
c_size_t
),
(
"dt_norm"
,
DataType
),
(
"dt_norm"
,
DataType
),
...
@@ -72,11 +72,11 @@ class JiugeWeights(ctypes.Structure):
...
@@ -72,11 +72,11 @@ class JiugeWeights(ctypes.Structure):
]
]
class
JiugeModel
(
ctypes
.
Structure
):
class
JiugeModel
CSruct
(
ctypes
.
Structure
):
pass
pass
class
KVCache
(
ctypes
.
Structure
):
class
KVCache
CStruct
(
ctypes
.
Structure
):
pass
pass
...
@@ -85,27 +85,27 @@ def __open_library__():
...
@@ -85,27 +85,27 @@ def __open_library__():
os
.
environ
.
get
(
"INFINI_ROOT"
),
"lib"
,
"libinfinicore_infer.so"
os
.
environ
.
get
(
"INFINI_ROOT"
),
"lib"
,
"libinfinicore_infer.so"
)
)
lib
=
ctypes
.
CDLL
(
lib_path
)
lib
=
ctypes
.
CDLL
(
lib_path
)
lib
.
createJiugeModel
.
restype
=
POINTER
(
JiugeModel
)
lib
.
createJiugeModel
.
restype
=
POINTER
(
JiugeModel
CSruct
)
lib
.
createJiugeModel
.
argtypes
=
[
lib
.
createJiugeModel
.
argtypes
=
[
POINTER
(
JiugeMeta
),
# JiugeMeta const *
POINTER
(
JiugeMeta
CStruct
),
# JiugeMeta const *
POINTER
(
JiugeWeights
),
# JiugeWeights const *
POINTER
(
JiugeWeights
CStruct
),
# JiugeWeights const *
DeviceType
,
# DeviceType
DeviceType
,
# DeviceType
c_int
,
# int ndev
c_int
,
# int ndev
POINTER
(
c_int
),
# int const *dev_ids
POINTER
(
c_int
),
# int const *dev_ids
]
]
lib
.
destroyJiugeModel
.
argtypes
=
[
POINTER
(
JiugeModel
)]
lib
.
destroyJiugeModel
.
argtypes
=
[
POINTER
(
JiugeModel
CSruct
)]
lib
.
createKVCache
.
argtypes
=
[
POINTER
(
JiugeModel
)]
lib
.
createKVCache
.
argtypes
=
[
POINTER
(
JiugeModel
CSruct
)]
lib
.
createKVCache
.
restype
=
POINTER
(
KVCache
)
lib
.
createKVCache
.
restype
=
POINTER
(
KVCache
CStruct
)
lib
.
dropKVCache
.
argtypes
=
[
POINTER
(
JiugeModel
),
POINTER
(
KVCache
)]
lib
.
dropKVCache
.
argtypes
=
[
POINTER
(
JiugeModel
CSruct
),
POINTER
(
KVCache
CStruct
)]
lib
.
inferBatch
.
restype
=
None
lib
.
inferBatch
.
restype
=
None
lib
.
inferBatch
.
argtypes
=
[
lib
.
inferBatch
.
argtypes
=
[
POINTER
(
JiugeModel
),
# struct JiugeModel const *
POINTER
(
JiugeModel
CSruct
),
# struct JiugeModel const *
POINTER
(
c_uint
),
# unsigned int const *tokens
POINTER
(
c_uint
),
# unsigned int const *tokens
c_uint
,
# unsigned int ntok
c_uint
,
# unsigned int ntok
POINTER
(
c_uint
),
# unsigned int const *req_lens
POINTER
(
c_uint
),
# unsigned int const *req_lens
c_uint
,
# unsigned int nreq
c_uint
,
# unsigned int nreq
POINTER
(
c_uint
),
# unsigned int const *req_pos
POINTER
(
c_uint
),
# unsigned int const *req_pos
POINTER
(
POINTER
(
KVCache
)),
# struct KVCache **kv_caches
POINTER
(
POINTER
(
KVCache
CStruct
)),
# struct KVCache **kv_caches
POINTER
(
c_float
),
# float temperature
POINTER
(
c_float
),
# float temperature
POINTER
(
c_uint
),
# unsigned int topk
POINTER
(
c_uint
),
# unsigned int topk
POINTER
(
c_float
),
# float topp
POINTER
(
c_float
),
# float topp
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment