Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
2b9ce5a6
"git@developer.sourcefind.cn:jerrrrry/infinilm.git" did not exist on "6ae4832227c8d6a1f20f9da76685f9fd74138468"
Commit
2b9ce5a6
authored
Jun 25, 2025
by
Pan Zezhong
Browse files
refactor task to avoid model dependency on async libs
parent
cfc8b598
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
171 additions
and
122 deletions
+171
-122
scripts/infer_task.py
scripts/infer_task.py
+33
-12
scripts/jiuge.py
scripts/jiuge.py
+105
-69
scripts/kvcache_pool.py
scripts/kvcache_pool.py
+8
-27
scripts/launch_server.py
scripts/launch_server.py
+12
-1
scripts/libinfinicore_infer.py
scripts/libinfinicore_infer.py
+13
-13
No files found.
scripts/infer_task.py
View file @
2b9ce5a6
import
janus
class
InferTask
:
class
InferTask
:
def
__init__
(
self
,
id
,
tokens
,
max_tokens
,
temperature
,
topk
,
topp
,
end_tokens
):
def
__init__
(
self
,
id
,
tokens
,
max_tokens
,
temperature
,
topk
,
topp
,
end_tokens
):
self
.
id
=
id
self
.
id
=
id
...
@@ -11,21 +8,24 @@ class InferTask:
...
@@ -11,21 +8,24 @@ class InferTask:
self
.
topk
=
topk
self
.
topk
=
topk
self
.
topp
=
topp
self
.
topp
=
topp
self
.
end_tokens
=
end_tokens
self
.
end_tokens
=
end_tokens
self
.
output_queue
=
janus
.
Queue
()
self
.
_kv_cache
=
None
self
.
_kv_cache_pool_item
=
None
self
.
pos
=
0
self
.
pos
=
0
print
(
f
"[INFO] Create InferTask
{
self
.
id
}
"
)
def
bind_kvcache
(
self
,
kv_cache
_pool_item
,
pos
):
def
bind_kvcache
(
self
,
kv_cache
,
pos
=
0
):
self
.
_kv_cache
_pool_item
=
kv_cache
_pool_item
self
.
_kv_cache
=
kv_cache
self
.
pos
=
pos
self
.
pos
=
pos
self
.
tokens
=
self
.
tokens
[
pos
:]
self
.
tokens
=
self
.
tokens
[
pos
:]
def
release_kvcache
(
self
):
cache
=
self
.
_kv_cache
self
.
_kv_cache
=
None
return
cache
def
kvcache
(
self
):
def
kvcache
(
self
):
return
self
.
_kv_cache
_pool_item
.
kvcache
return
self
.
_kv_cache
def
outpu
t
(
self
,
out_token
):
def
nex
t
(
self
,
out_token
):
self
.
_kv_cache
_pool_item
.
update_tokens
(
self
.
tokens
,
self
.
pos
)
self
.
_kv_cache
.
update_tokens
(
self
.
tokens
,
self
.
pos
)
self
.
pos
+=
len
(
self
.
tokens
)
self
.
pos
+=
len
(
self
.
tokens
)
if
out_token
==
None
or
out_token
in
self
.
end_tokens
:
if
out_token
==
None
or
out_token
in
self
.
end_tokens
:
...
@@ -35,4 +35,25 @@ class InferTask:
...
@@ -35,4 +35,25 @@ class InferTask:
else
:
else
:
self
.
tokens
=
[
out_token
]
self
.
tokens
=
[
out_token
]
self
.
output_queue
.
sync_q
.
put
(
out_token
)
class
KVCache
:
def
__init__
(
self
,
model
):
self
.
_kvcache
=
model
.
create_kv_cache
()
self
.
tokens
=
[
0
for
_
in
range
(
model
.
max_context_len
())]
def
data
(
self
):
return
self
.
_kvcache
def
drop
(
self
,
model
):
model
.
drop_kv_cache
(
self
.
_kvcache
)
def
update_tokens
(
self
,
tokens
,
pos
):
end
=
pos
+
len
(
tokens
)
max_len
=
len
(
self
.
tokens
)
# If overflow, truncate tokens to fit
if
end
>
max_len
:
tokens
=
tokens
[:
max_len
-
pos
]
end
=
max_len
self
.
tokens
[
pos
:
end
]
=
tokens
scripts/jiuge.py
View file @
2b9ce5a6
from
typing
import
List
from
typing
import
List
from
libinfinicore_infer
import
(
from
libinfinicore_infer
import
(
JiugeMeta
,
JiugeMeta
CStruct
,
JiugeWeights
,
JiugeWeights
CStruct
,
KVCache
,
KVCache
CStruct
,
DataType
,
DataType
,
DeviceType
,
DeviceType
,
create_jiuge_model
,
create_jiuge_model
,
...
@@ -11,7 +11,7 @@ from libinfinicore_infer import (
...
@@ -11,7 +11,7 @@ from libinfinicore_infer import (
drop_kv_cache
,
drop_kv_cache
,
infer_batch
,
infer_batch
,
)
)
from
infer_task
import
InferTask
from
infer_task
import
InferTask
,
KVCache
from
ctypes
import
POINTER
,
c_float
,
c_int
,
c_uint
,
c_void_p
,
byref
from
ctypes
import
POINTER
,
c_float
,
c_int
,
c_uint
,
c_void_p
,
byref
import
os
import
os
...
@@ -25,6 +25,7 @@ import transformers
...
@@ -25,6 +25,7 @@ import transformers
torch
.
set_default_device
(
"cpu"
)
torch
.
set_default_device
(
"cpu"
)
class
LlamaWeightsNaming
:
class
LlamaWeightsNaming
:
def
input_embd
(
self
):
def
input_embd
(
self
):
return
"model.embed_tokens.weight"
return
"model.embed_tokens.weight"
...
@@ -78,7 +79,7 @@ class LlamaWeightsNaming:
...
@@ -78,7 +79,7 @@ class LlamaWeightsNaming:
)
)
class
JiugeMetaFromLlama
(
JiugeMeta
):
class
JiugeMetaFromLlama
(
JiugeMeta
CStruct
):
def
__init__
(
self
,
config
,
dtype
=
torch
.
float16
):
def
__init__
(
self
,
config
,
dtype
=
torch
.
float16
):
if
dtype
==
torch
.
float16
:
if
dtype
==
torch
.
float16
:
dt_
=
DataType
.
INFINI_DTYPE_F16
dt_
=
DataType
.
INFINI_DTYPE_F16
...
@@ -107,7 +108,7 @@ class JiugeMetaFromLlama(JiugeMeta):
...
@@ -107,7 +108,7 @@ class JiugeMetaFromLlama(JiugeMeta):
self
.
torch_dtype_logits
=
dtype
self
.
torch_dtype_logits
=
dtype
class
JiugeWeightsImpl
(
JiugeWeights
):
class
JiugeWeightsImpl
(
JiugeWeights
CStruct
):
def
__init__
(
def
__init__
(
self
,
self
,
meta
,
meta
,
...
@@ -160,7 +161,9 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -160,7 +161,9 @@ class JiugeWeightsImpl(JiugeWeights):
self
.
output_norm
=
self
.
output_norm_tensor
.
data_ptr
()
self
.
output_norm
=
self
.
output_norm_tensor
.
data_ptr
()
self
.
output_embd_tensor
=
state_dict
[
output_embd_naming
].
to
(
torch_dt_mat
)
self
.
output_embd_tensor
=
state_dict
[
output_embd_naming
].
to
(
torch_dt_mat
)
if
not
transpose_weight
:
if
not
transpose_weight
:
self
.
output_embd_tensor
=
self
.
output_embd_tensor
.
transpose
(
0
,
1
).
contiguous
()
self
.
output_embd_tensor
=
self
.
output_embd_tensor
.
transpose
(
0
,
1
).
contiguous
()
self
.
output_embd
=
self
.
output_embd_tensor
.
data_ptr
()
self
.
output_embd
=
self
.
output_embd_tensor
.
data_ptr
()
self
.
attn_norm_tensors
=
[
self
.
attn_norm_tensors
=
[
...
@@ -197,7 +200,12 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -197,7 +200,12 @@ class JiugeWeightsImpl(JiugeWeights):
]
]
if
not
transpose_weight
:
if
not
transpose_weight
:
for
i
in
range
(
nlayer
):
for
i
in
range
(
nlayer
):
self
.
qkv_tensor
[
i
]
=
self
.
qkv_tensor
[
i
].
reshape
(
ndev
,
(
nh
+
2
*
nkvh
)
//
ndev
*
dh
,
d
).
transpose
(
1
,
2
).
contiguous
()
self
.
qkv_tensor
[
i
]
=
(
self
.
qkv_tensor
[
i
]
.
reshape
(
ndev
,
(
nh
+
2
*
nkvh
)
//
ndev
*
dh
,
d
)
.
transpose
(
1
,
2
)
.
contiguous
()
)
self
.
qkv_tensor_ptrs
=
[
self
.
qkv_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
qkv_tensor_ptrs
=
[
self
.
qkv_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
attn_qkv
=
(
c_void_p
*
nlayer
)(
*
self
.
qkv_tensor_ptrs
)
self
.
attn_qkv
=
(
c_void_p
*
nlayer
)(
*
self
.
qkv_tensor_ptrs
)
...
@@ -234,13 +242,18 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -234,13 +242,18 @@ class JiugeWeightsImpl(JiugeWeights):
self
.
attn_qkv_b
=
None
self
.
attn_qkv_b
=
None
self
.
attn_o_tensor
=
[
self
.
attn_o_tensor
=
[
state_dict
[
naming
.
attn_o
(
i
)]
(
state_dict
[
naming
.
attn_o
(
i
)]
.
to
(
torch_dt_mat
)
.
to
(
torch_dt_mat
)
.
reshape
([
d
,
ndev
,
nh
//
ndev
*
dh
])
.
reshape
([
d
,
ndev
,
nh
//
ndev
*
dh
])
.
transpose
(
0
,
1
)
.
transpose
(
0
,
1
)
.
contiguous
()
.
contiguous
()
if
transpose_weight
if
transpose_weight
else
state_dict
[
naming
.
attn_o
(
i
)].
transpose
(
0
,
1
).
to
(
torch_dt_mat
).
contiguous
()
else
state_dict
[
naming
.
attn_o
(
i
)]
.
transpose
(
0
,
1
)
.
to
(
torch_dt_mat
)
.
contiguous
()
)
for
i
in
range
(
nlayer
)
for
i
in
range
(
nlayer
)
]
]
self
.
attn_o_ptrs
=
[
self
.
attn_o_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
attn_o_ptrs
=
[
self
.
attn_o_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
...
@@ -269,18 +282,28 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -269,18 +282,28 @@ class JiugeWeightsImpl(JiugeWeights):
]
]
if
not
transpose_weight
:
if
not
transpose_weight
:
for
i
in
range
(
nlayer
):
for
i
in
range
(
nlayer
):
self
.
gate_up_tensors
[
i
]
=
self
.
gate_up_tensors
[
i
].
reshape
(
ndev
,
2
*
di
//
ndev
,
d
).
transpose
(
1
,
2
).
contiguous
()
self
.
gate_up_tensors
[
i
]
=
(
self
.
gate_up_tensors
[
i
]
.
reshape
(
ndev
,
2
*
di
//
ndev
,
d
)
.
transpose
(
1
,
2
)
.
contiguous
()
)
self
.
gate_up_ptrs
=
[
self
.
gate_up_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
gate_up_ptrs
=
[
self
.
gate_up_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
ffn_gate_up
=
(
c_void_p
*
nlayer
)(
*
self
.
gate_up_ptrs
)
self
.
ffn_gate_up
=
(
c_void_p
*
nlayer
)(
*
self
.
gate_up_ptrs
)
self
.
ffn_down_tensor
=
[
self
.
ffn_down_tensor
=
[
state_dict
[
naming
.
down
(
i
)]
(
.
to
(
torch_dt_mat
)
state_dict
[
naming
.
down
(
i
)]
.
reshape
([
d
,
ndev
,
di
//
ndev
])
.
to
(
torch_dt_mat
)
.
transpose
(
0
,
1
)
.
reshape
([
d
,
ndev
,
di
//
ndev
])
.
contiguous
()
.
transpose
(
0
,
1
)
if
transpose_weight
.
contiguous
()
else
state_dict
[
naming
.
down
(
i
)].
transpose
(
0
,
1
).
to
(
torch_dt_mat
).
contiguous
()
if
transpose_weight
else
state_dict
[
naming
.
down
(
i
)]
.
transpose
(
0
,
1
)
.
to
(
torch_dt_mat
)
.
contiguous
()
)
for
i
in
range
(
nlayer
)
for
i
in
range
(
nlayer
)
]
]
self
.
ffn_down_ptrs
=
[
self
.
ffn_down_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
ffn_down_ptrs
=
[
self
.
ffn_down_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
...
@@ -296,7 +319,7 @@ class JiugeBatchedTask:
...
@@ -296,7 +319,7 @@ class JiugeBatchedTask:
token_lists
=
[
t
.
tokens
for
t
in
tasks
]
token_lists
=
[
t
.
tokens
for
t
in
tasks
]
self
.
req_lens_list
=
[
len
(
toks
)
for
toks
in
token_lists
]
self
.
req_lens_list
=
[
len
(
toks
)
for
toks
in
token_lists
]
self
.
req_pos_list
=
[
t
.
pos
for
t
in
tasks
]
self
.
req_pos_list
=
[
t
.
pos
for
t
in
tasks
]
self
.
kv_cache_ptrs
=
[
t
.
kvcache
()
for
t
in
tasks
]
self
.
kv_cache_ptrs
=
[
t
.
kvcache
()
.
data
()
for
t
in
tasks
]
self
.
temperaturas_list
=
[
t
.
temperature
for
t
in
tasks
]
self
.
temperaturas_list
=
[
t
.
temperature
for
t
in
tasks
]
self
.
topks_list
=
[
t
.
topk
for
t
in
tasks
]
self
.
topks_list
=
[
t
.
topk
for
t
in
tasks
]
self
.
topps_list
=
[
t
.
topp
for
t
in
tasks
]
self
.
topps_list
=
[
t
.
topp
for
t
in
tasks
]
...
@@ -309,7 +332,7 @@ class JiugeBatchedTask:
...
@@ -309,7 +332,7 @@ class JiugeBatchedTask:
self
.
tokens
=
(
c_uint
*
self
.
ntok
)(
*
flat_tokens
)
self
.
tokens
=
(
c_uint
*
self
.
ntok
)(
*
flat_tokens
)
self
.
req_lens
=
(
c_uint
*
self
.
nreq
)(
*
self
.
req_lens_list
)
self
.
req_lens
=
(
c_uint
*
self
.
nreq
)(
*
self
.
req_lens_list
)
self
.
req_pos
=
(
c_uint
*
self
.
nreq
)(
*
self
.
req_pos_list
)
self
.
req_pos
=
(
c_uint
*
self
.
nreq
)(
*
self
.
req_pos_list
)
self
.
kv_caches
=
(
POINTER
(
KVCache
)
*
self
.
nreq
)(
*
self
.
kv_cache_ptrs
)
self
.
kv_caches
=
(
POINTER
(
KVCache
CStruct
)
*
self
.
nreq
)(
*
self
.
kv_cache_ptrs
)
self
.
temperaturas
=
(
c_float
*
self
.
nreq
)(
*
self
.
temperaturas_list
)
self
.
temperaturas
=
(
c_float
*
self
.
nreq
)(
*
self
.
temperaturas_list
)
self
.
topks
=
(
c_uint
*
self
.
nreq
)(
*
self
.
topks_list
)
self
.
topks
=
(
c_uint
*
self
.
nreq
)(
*
self
.
topks_list
)
self
.
topps
=
(
c_float
*
self
.
nreq
)(
*
self
.
topps_list
)
self
.
topps
=
(
c_float
*
self
.
nreq
)(
*
self
.
topps_list
)
...
@@ -338,7 +361,7 @@ class JiugeForCauslLM:
...
@@ -338,7 +361,7 @@ class JiugeForCauslLM:
for
name_
in
data_
.
keys
():
for
name_
in
data_
.
keys
():
tensors_
[
name_
]
=
data_
.
get_tensor
(
name_
)
tensors_
[
name_
]
=
data_
.
get_tensor
(
name_
)
return
tensors_
return
tensors_
print
(
"Loading model weights to host..."
)
print
(
"Loading model weights to host..."
)
load_start_time
=
time
.
time
()
load_start_time
=
time
.
time
()
...
@@ -346,26 +369,46 @@ class JiugeForCauslLM:
...
@@ -346,26 +369,46 @@ class JiugeForCauslLM:
config
=
json
.
load
(
f
)
config
=
json
.
load
(
f
)
self
.
config
=
config
self
.
config
=
config
eos_token_id
=
self
.
config
[
"eos_token_id"
]
eos_token_id
=
self
.
config
[
"eos_token_id"
]
self
.
eos_token_id
=
[
eos_token_id
]
if
type
(
eos_token_id
)
==
int
else
eos_token_id
self
.
eos_token_id
=
(
transpose_weight
=
device
!=
DeviceType
.
DEVICE_TYPE_ASCEND
# y = xW is faster than y=xW^T on Ascend
[
eos_token_id
]
if
type
(
eos_token_id
)
==
int
else
eos_token_id
)
transpose_weight
=
(
device
!=
DeviceType
.
DEVICE_TYPE_ASCEND
)
# y = xW is faster than y=xW^T on Ascend
if
"llama"
==
config
[
"model_type"
]:
if
"llama"
==
config
[
"model_type"
]:
model
=
transformers
.
LlamaForCausalLM
.
from_pretrained
(
model_dir_path
).
cpu
().
half
()
model
=
(
transformers
.
LlamaForCausalLM
.
from_pretrained
(
model_dir_path
)
.
cpu
()
.
half
()
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
model
.
state_dict
(),
ndev
=
ndev
,
transpose_weight
=
transpose_weight
self
.
meta
,
LlamaWeightsNaming
(),
model
.
state_dict
(),
ndev
=
ndev
,
transpose_weight
=
transpose_weight
,
)
)
elif
"fm9g"
==
config
[
"model_type"
]:
elif
"fm9g"
==
config
[
"model_type"
]:
if
any
(
file
.
suffix
==
".safetensors"
for
file
in
Path
(
model_dir_path
).
iterdir
()):
if
any
(
file
.
suffix
==
".safetensors"
for
file
in
Path
(
model_dir_path
).
iterdir
()
):
state_dict
=
load_all_safetensors_from_dir
(
model_dir_path
)
state_dict
=
load_all_safetensors_from_dir
(
model_dir_path
)
else
:
else
:
state_dict
=
torch
.
load
(
state_dict
=
torch
.
load
(
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
weights_only
=
True
,
map_location
=
"cpu"
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
)
weights_only
=
True
,
map_location
=
"cpu"
,
)
if
LlamaWeightsNaming
.
match
(
state_dict
):
if
LlamaWeightsNaming
.
match
(
state_dict
):
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
,
)
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
,
trust_remote_code
=
True
model_dir_path
,
trust_remote_code
=
True
...
@@ -374,12 +417,18 @@ class JiugeForCauslLM:
...
@@ -374,12 +417,18 @@ class JiugeForCauslLM:
raise
ValueError
(
"Unsupported weight naming"
)
raise
ValueError
(
"Unsupported weight naming"
)
elif
"fm9g7b"
==
config
[
"model_type"
]:
elif
"fm9g7b"
==
config
[
"model_type"
]:
state_dict
=
torch
.
load
(
state_dict
=
torch
.
load
(
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
weights_only
=
True
,
map_location
=
"cpu"
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
weights_only
=
True
,
map_location
=
"cpu"
,
)
)
if
LlamaWeightsNaming
.
match
(
state_dict
):
if
LlamaWeightsNaming
.
match
(
state_dict
):
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
,
)
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
,
trust_remote_code
=
True
model_dir_path
,
trust_remote_code
=
True
...
@@ -391,7 +440,11 @@ class JiugeForCauslLM:
...
@@ -391,7 +440,11 @@ class JiugeForCauslLM:
if
LlamaWeightsNaming
.
match
(
state_dict
):
if
LlamaWeightsNaming
.
match
(
state_dict
):
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
,
transpose_weight
=
transpose_weight
,
)
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
model_dir_path
...
@@ -401,7 +454,7 @@ class JiugeForCauslLM:
...
@@ -401,7 +454,7 @@ class JiugeForCauslLM:
load_end_time
=
time
.
time
()
load_end_time
=
time
.
time
()
print
(
f
"Time used:
{
load_end_time
-
load_start_time
:.
3
f
}
s"
)
print
(
f
"Time used:
{
load_end_time
-
load_start_time
:.
3
f
}
s"
)
print
(
f
"Creating model on
{
ndev
}
devices..."
)
print
(
f
"Creating model on
{
ndev
}
devices..."
)
load_start_time
=
time
.
time
()
load_start_time
=
time
.
time
()
dev_ids
=
(
c_int
*
ndev
)(
*
[
i
for
i
in
range
(
ndev
)])
dev_ids
=
(
c_int
*
ndev
)(
*
[
i
for
i
in
range
(
ndev
)])
...
@@ -414,10 +467,10 @@ class JiugeForCauslLM:
...
@@ -414,10 +467,10 @@ class JiugeForCauslLM:
)
)
load_end_time
=
time
.
time
()
load_end_time
=
time
.
time
()
print
(
f
"Time used:
{
load_end_time
-
load_start_time
:.
3
f
}
s"
)
print
(
f
"Time used:
{
load_end_time
-
load_start_time
:.
3
f
}
s"
)
def
max_context_len
(
self
):
def
max_context_len
(
self
):
return
self
.
meta
.
dctx
return
self
.
meta
.
dctx
def
create_kv_cache
(
self
):
def
create_kv_cache
(
self
):
return
create_kv_cache
(
self
.
model_instance
)
return
create_kv_cache
(
self
.
model_instance
)
...
@@ -435,7 +488,6 @@ class JiugeForCauslLM:
...
@@ -435,7 +488,6 @@ class JiugeForCauslLM:
return
list
(
output
)
return
list
(
output
)
def
generate
(
self
,
input_content
,
max_steps
,
topp_
=
1.0
,
topk_
=
1
,
temperature_
=
1.0
):
def
generate
(
self
,
input_content
,
max_steps
,
topp_
=
1.0
,
topk_
=
1
,
temperature_
=
1.0
):
kv_cache
=
create_kv_cache
(
self
.
model_instance
)
input_content
=
self
.
tokenizer
.
apply_chat_template
(
input_content
=
self
.
tokenizer
.
apply_chat_template
(
conversation
=
[{
"role"
:
"user"
,
"content"
:
input_content
}],
conversation
=
[{
"role"
:
"user"
,
"content"
:
input_content
}],
add_generation_prompt
=
True
,
add_generation_prompt
=
True
,
...
@@ -443,39 +495,26 @@ class JiugeForCauslLM:
...
@@ -443,39 +495,26 @@ class JiugeForCauslLM:
)
)
print
(
input_content
,
end
=
""
,
flush
=
True
)
print
(
input_content
,
end
=
""
,
flush
=
True
)
tokens
=
self
.
tokenizer
.
encode
(
input_content
)
tokens
=
self
.
tokenizer
.
encode
(
input_content
)
ntok
=
len
(
tokens
)
infer_task
=
InferTask
(
nreq
=
1
0
,
output_content
=
""
tokens
,
tokens
=
(
c_uint
*
ntok
)(
*
tokens
)
self
.
max_context_len
(),
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
temperature_
,
req_pos
=
(
c_uint
*
nreq
)(
*
[
0
])
topk_
,
kv_caches
=
(
POINTER
(
KVCache
)
*
nreq
)(
*
[
kv_cache
])
topp_
,
ans
=
(
c_uint
*
nreq
)()
self
.
eos_token_id
,
temperature
=
(
c_float
*
nreq
)(
*
[
temperature_
])
)
topk
=
(
c_uint
*
nreq
)(
*
[
topk_
])
infer_task
.
bind_kvcache
(
KVCache
(
self
))
topp
=
(
c_float
*
nreq
)(
*
[
topp_
])
steps
=
0
steps
=
0
total_time
=
0
total_time
=
0
output_content
=
""
for
step_i
in
range
(
max_steps
):
for
step_i
in
range
(
max_steps
):
start_time
=
time
.
time
()
start_time
=
time
.
time
()
infer_batch
(
output_tokens
=
self
.
batch_infer_one_round
([
infer_task
])
self
.
model_instance
,
tokens
,
ntok
,
req_lens
,
nreq
,
req_pos
,
kv_caches
,
temperature
,
topk
,
topp
,
ans
,
)
steps
+=
1
output_tokens
=
list
(
ans
)
end_time
=
time
.
time
()
end_time
=
time
.
time
()
steps
+=
1
output_str
=
(
output_str
=
(
self
.
tokenizer
.
_tokenizer
.
id_to_token
(
output_tokens
[
0
])
self
.
tokenizer
.
_tokenizer
.
id_to_token
(
output_tokens
[
0
])
.
replace
(
"▁"
,
" "
)
.
replace
(
"▁"
,
" "
)
...
@@ -485,10 +524,7 @@ class JiugeForCauslLM:
...
@@ -485,10 +524,7 @@ class JiugeForCauslLM:
print
(
output_str
,
end
=
""
,
flush
=
True
)
print
(
output_str
,
end
=
""
,
flush
=
True
)
if
output_tokens
[
0
]
in
self
.
eos_token_id
:
if
output_tokens
[
0
]
in
self
.
eos_token_id
:
break
break
req_pos
[
0
]
=
req_pos
[
0
]
+
ntok
infer_task
.
next
(
output_tokens
[
0
])
ntok
=
1
tokens
=
(
c_uint
*
ntok
)(
*
output_tokens
)
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
if
step_i
>
0
:
if
step_i
>
0
:
total_time
+=
end_time
-
start_time
total_time
+=
end_time
-
start_time
...
@@ -496,8 +532,8 @@ class JiugeForCauslLM:
...
@@ -496,8 +532,8 @@ class JiugeForCauslLM:
print
(
"
\n
"
)
print
(
"
\n
"
)
avg_time
=
total_time
*
1000
/
(
steps
-
1
)
avg_time
=
total_time
*
1000
/
(
steps
-
1
)
print
(
f
"Time per step:
{
avg_time
:.
3
f
}
ms"
)
print
(
f
"Time per step:
{
avg_time
:.
3
f
}
ms"
)
for
kv_cache
in
kv_caches
:
drop_kv_cache
(
self
.
model_instance
,
kv_cache
)
infer_task
.
_kv_cache
.
drop
(
self
)
return
output_content
,
avg_time
return
output_content
,
avg_time
def
destroy_model_instance
(
self
):
def
destroy_model_instance
(
self
):
...
...
scripts/kvcache_pool.py
View file @
2b9ce5a6
from
infer_task
import
KVCache
import
asyncio
import
asyncio
from
typing
import
List
from
typing
import
List
class
KVCachePoolItem
:
def
__init__
(
self
,
model
):
self
.
kvcache
=
model
.
create_kv_cache
()
self
.
tokens
=
[
0
for
_
in
range
(
model
.
max_context_len
())]
def
drop
(
self
,
model
):
model
.
drop_kv_cache
(
self
.
kvcache
)
def
update_tokens
(
self
,
tokens
,
pos
):
end
=
pos
+
len
(
tokens
)
max_len
=
len
(
self
.
tokens
)
# If overflow, truncate tokens to fit
if
end
>
max_len
:
tokens
=
tokens
[:
max_len
-
pos
]
end
=
max_len
self
.
tokens
[
pos
:
end
]
=
tokens
import
threading
import
threading
...
@@ -29,7 +9,7 @@ class KVCachePool:
...
@@ -29,7 +9,7 @@ class KVCachePool:
def
__init__
(
self
,
model
,
max_caches
:
int
=
32
):
def
__init__
(
self
,
model
,
max_caches
:
int
=
32
):
self
.
max_caches
=
max_caches
self
.
max_caches
=
max_caches
self
.
model
=
model
self
.
model
=
model
self
.
_available
:
List
[
KVCache
PoolItem
]
=
[]
self
.
_available
:
List
[
KVCache
]
=
[]
self
.
num_caches
=
len
(
self
.
_available
)
self
.
num_caches
=
len
(
self
.
_available
)
self
.
_lock
=
threading
.
Lock
()
self
.
_lock
=
threading
.
Lock
()
self
.
_not_empty
=
threading
.
Condition
(
self
.
_lock
)
self
.
_not_empty
=
threading
.
Condition
(
self
.
_lock
)
...
@@ -45,8 +25,10 @@ class KVCachePool:
...
@@ -45,8 +25,10 @@ class KVCachePool:
if
len
(
self
.
_available
)
==
0
:
if
len
(
self
.
_available
)
==
0
:
if
self
.
num_caches
<
self
.
max_caches
:
if
self
.
num_caches
<
self
.
max_caches
:
self
.
num_caches
+=
1
self
.
num_caches
+=
1
print
(
f
"[INFO] Task
{
infer_task
.
id
}
created new KVCachePoolItem"
)
print
(
return
infer_task
.
bind_kvcache
(
KVCachePoolItem
(
self
.
model
),
0
)
f
"[INFO] Task
{
infer_task
.
id
}
created new KVCachePoolItem"
)
return
infer_task
.
bind_kvcache
(
KVCache
(
self
.
model
),
0
)
else
:
else
:
self
.
_not_empty
.
wait
()
self
.
_not_empty
.
wait
()
else
:
else
:
...
@@ -62,8 +44,7 @@ class KVCachePool:
...
@@ -62,8 +44,7 @@ class KVCachePool:
def
release_sync
(
self
,
infer_task
):
def
release_sync
(
self
,
infer_task
):
with
self
.
_not_empty
:
with
self
.
_not_empty
:
print
(
f
"[INFO] Task
{
infer_task
.
id
}
returned KVCachePoolItem to pool"
)
print
(
f
"[INFO] Task
{
infer_task
.
id
}
returned KVCachePoolItem to pool"
)
self
.
_available
.
append
(
infer_task
.
_kv_cache_pool_item
)
self
.
_available
.
append
(
infer_task
.
release_kvcache
())
infer_task
.
_kv_cache_pool_item
=
None
self
.
_not_empty
.
notify
()
self
.
_not_empty
.
notify
()
async
def
acquire
(
self
,
infer_task
):
async
def
acquire
(
self
,
infer_task
):
...
...
scripts/launch_server.py
View file @
2b9ce5a6
...
@@ -70,6 +70,17 @@ print(
...
@@ -70,6 +70,17 @@ print(
f
"Using MAX_BATCH=
{
MAX_BATCH
}
. Try reduce this value if out of memory error occurs."
f
"Using MAX_BATCH=
{
MAX_BATCH
}
. Try reduce this value if out of memory error occurs."
)
)
# A wrapper for InferTask that supports async output queue
class
AsyncInferTask
(
InferTask
):
def
__init__
(
self
,
id
,
tokens
,
max_tokens
,
temperature
,
topk
,
topp
,
end_tokens
):
super
().
__init__
(
id
,
tokens
,
max_tokens
,
temperature
,
topk
,
topp
,
end_tokens
)
self
.
output_queue
=
janus
.
Queue
()
print
(
f
"[INFO] Create InferTask
{
self
.
id
}
"
)
def
output
(
self
,
out_token
):
self
.
next
(
out_token
)
self
.
output_queue
.
sync_q
.
put
(
out_token
)
@
contextlib
.
asynccontextmanager
@
contextlib
.
asynccontextmanager
async
def
lifespan
(
app
:
FastAPI
):
async
def
lifespan
(
app
:
FastAPI
):
...
@@ -132,7 +143,7 @@ def build_task(id_, request_data, request: Request):
...
@@ -132,7 +143,7 @@ def build_task(id_, request_data, request: Request):
tokenize
=
False
,
tokenize
=
False
,
)
)
tokens
=
request
.
app
.
state
.
model
.
tokenizer
.
encode
(
input_content
)
tokens
=
request
.
app
.
state
.
model
.
tokenizer
.
encode
(
input_content
)
return
InferTask
(
return
Async
InferTask
(
id_
,
id_
,
tokens
,
tokens
,
request_data
.
get
(
"max_tokens"
,
request
.
app
.
state
.
model
.
max_context_len
()),
request_data
.
get
(
"max_tokens"
,
request
.
app
.
state
.
model
.
max_context_len
()),
...
...
scripts/libinfinicore_infer.py
View file @
2b9ce5a6
...
@@ -35,7 +35,7 @@ class DeviceType(ctypes.c_int):
...
@@ -35,7 +35,7 @@ class DeviceType(ctypes.c_int):
DEVICE_TYPE_MOORE
=
5
DEVICE_TYPE_MOORE
=
5
class
JiugeMeta
(
ctypes
.
Structure
):
class
JiugeMeta
CStruct
(
ctypes
.
Structure
):
_fields_
=
[
_fields_
=
[
(
"dt_logits"
,
DataType
),
(
"dt_logits"
,
DataType
),
(
"nlayer"
,
c_size_t
),
(
"nlayer"
,
c_size_t
),
...
@@ -53,7 +53,7 @@ class JiugeMeta(ctypes.Structure):
...
@@ -53,7 +53,7 @@ class JiugeMeta(ctypes.Structure):
# Define the JiugeWeights struct
# Define the JiugeWeights struct
class
JiugeWeights
(
ctypes
.
Structure
):
class
JiugeWeights
CStruct
(
ctypes
.
Structure
):
_fields_
=
[
_fields_
=
[
(
"nlayer"
,
c_size_t
),
(
"nlayer"
,
c_size_t
),
(
"dt_norm"
,
DataType
),
(
"dt_norm"
,
DataType
),
...
@@ -72,11 +72,11 @@ class JiugeWeights(ctypes.Structure):
...
@@ -72,11 +72,11 @@ class JiugeWeights(ctypes.Structure):
]
]
class
JiugeModel
(
ctypes
.
Structure
):
class
JiugeModel
CSruct
(
ctypes
.
Structure
):
pass
pass
class
KVCache
(
ctypes
.
Structure
):
class
KVCache
CStruct
(
ctypes
.
Structure
):
pass
pass
...
@@ -85,27 +85,27 @@ def __open_library__():
...
@@ -85,27 +85,27 @@ def __open_library__():
os
.
environ
.
get
(
"INFINI_ROOT"
),
"lib"
,
"libinfinicore_infer.so"
os
.
environ
.
get
(
"INFINI_ROOT"
),
"lib"
,
"libinfinicore_infer.so"
)
)
lib
=
ctypes
.
CDLL
(
lib_path
)
lib
=
ctypes
.
CDLL
(
lib_path
)
lib
.
createJiugeModel
.
restype
=
POINTER
(
JiugeModel
)
lib
.
createJiugeModel
.
restype
=
POINTER
(
JiugeModel
CSruct
)
lib
.
createJiugeModel
.
argtypes
=
[
lib
.
createJiugeModel
.
argtypes
=
[
POINTER
(
JiugeMeta
),
# JiugeMeta const *
POINTER
(
JiugeMeta
CStruct
),
# JiugeMeta const *
POINTER
(
JiugeWeights
),
# JiugeWeights const *
POINTER
(
JiugeWeights
CStruct
),
# JiugeWeights const *
DeviceType
,
# DeviceType
DeviceType
,
# DeviceType
c_int
,
# int ndev
c_int
,
# int ndev
POINTER
(
c_int
),
# int const *dev_ids
POINTER
(
c_int
),
# int const *dev_ids
]
]
lib
.
destroyJiugeModel
.
argtypes
=
[
POINTER
(
JiugeModel
)]
lib
.
destroyJiugeModel
.
argtypes
=
[
POINTER
(
JiugeModel
CSruct
)]
lib
.
createKVCache
.
argtypes
=
[
POINTER
(
JiugeModel
)]
lib
.
createKVCache
.
argtypes
=
[
POINTER
(
JiugeModel
CSruct
)]
lib
.
createKVCache
.
restype
=
POINTER
(
KVCache
)
lib
.
createKVCache
.
restype
=
POINTER
(
KVCache
CStruct
)
lib
.
dropKVCache
.
argtypes
=
[
POINTER
(
JiugeModel
),
POINTER
(
KVCache
)]
lib
.
dropKVCache
.
argtypes
=
[
POINTER
(
JiugeModel
CSruct
),
POINTER
(
KVCache
CStruct
)]
lib
.
inferBatch
.
restype
=
None
lib
.
inferBatch
.
restype
=
None
lib
.
inferBatch
.
argtypes
=
[
lib
.
inferBatch
.
argtypes
=
[
POINTER
(
JiugeModel
),
# struct JiugeModel const *
POINTER
(
JiugeModel
CSruct
),
# struct JiugeModel const *
POINTER
(
c_uint
),
# unsigned int const *tokens
POINTER
(
c_uint
),
# unsigned int const *tokens
c_uint
,
# unsigned int ntok
c_uint
,
# unsigned int ntok
POINTER
(
c_uint
),
# unsigned int const *req_lens
POINTER
(
c_uint
),
# unsigned int const *req_lens
c_uint
,
# unsigned int nreq
c_uint
,
# unsigned int nreq
POINTER
(
c_uint
),
# unsigned int const *req_pos
POINTER
(
c_uint
),
# unsigned int const *req_pos
POINTER
(
POINTER
(
KVCache
)),
# struct KVCache **kv_caches
POINTER
(
POINTER
(
KVCache
CStruct
)),
# struct KVCache **kv_caches
POINTER
(
c_float
),
# float temperature
POINTER
(
c_float
),
# float temperature
POINTER
(
c_uint
),
# unsigned int topk
POINTER
(
c_uint
),
# unsigned int topk
POINTER
(
c_float
),
# float topp
POINTER
(
c_float
),
# float topp
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment