jerrrrry / infinilm

Commit dcd6693f, authored Jun 24, 2025 by Pan Zezhong
parent 21f83e91

    Support continuous batching

Showing 9 changed files with 308 additions and 250 deletions (+308 / -250)
Files changed:

    include/infinicore_infer/models/jiuge.h   +3   -19
    scripts/infer_task.py                     +26  -20
    scripts/jiuge.py                          +66  -122
    scripts/kvcache_pool.py                   +48  -18
    scripts/launch_server.py                  +141 -54
    scripts/libinfinicore_infer.py            +4   -3
    scripts/test_server.py                    +5   -1
    src/models/jiuge/jiuge.cpp                +11  -9
    src/models/jiuge/jiuge_impl.hpp           +4   -4
include/infinicore_infer/models/jiuge.h  (+3 / -19)  view file @ dcd6693f

@@ -75,22 +75,6 @@ __C __export void
 dropKVCache(const struct JiugeModel *,
             struct KVCache *);
 
-/// @brief Text generation
-/// @param tokens input tokens
-/// @param ntok number of input tokens
-/// @param req_pos starting position of the request
-/// @param output output token address
-/// @param max_step maximum number of output tokens
-/// @param temperature sampling temperature (0. means greedy sampling)
-/// @param topk sampling top-k (1 means greedy sampling)
-/// @param topp sampling top-p
-__C __export void
-generate(struct JiugeModel *,
-         struct KVCache *,
-         const uint32_t *tokens, uint32_t ntok,
-         uint32_t req_pos,
-         uint32_t *output, uint32_t max_step,
-         float temperature, uint32_t topk, float topp);
 /// @brief One round of batched inference
 /// @param tokens input token address
 /// @param ntok number of input tokens
...
@@ -98,16 +82,16 @@ generate(struct JiugeModel *,
 /// @param req_lens number of tokens in each request
 /// @param req_pos starting position of each request
 /// @param kv_caches KV cache of each request
-/// @param ans output token array, one output per request, length at least nreq
 /// @param temperature sampling temperature (0. means greedy sampling)
 /// @param topk sampling top-k (1 means greedy sampling)
 /// @param topp sampling top-p
+/// @param output output token array, one output per request, length at least nreq
 __C __export void
 inferBatch(struct JiugeModel *,
            const uint32_t *tokens, uint32_t ntok,
            const uint32_t *req_lens, uint32_t nreq,
            const uint32_t *req_pos,
            struct KVCache **kv_caches,
-           uint32_t *output,
-           float temperature, uint32_t topk, float topp);
+           const float *temperature, const uint32_t *topk, const float *topp,
+           uint32_t *output);
 #endif
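Note on the API change: temperature, topk, and topp are now per-request arrays, and the output pointer moves to the last position, so each request in a batch can be sampled with its own settings. A minimal ctypes sketch of a two-request call under the new signature (illustrative values; `lib`, `model`, `KVCache`, and the two cache handles are assumed to come from scripts/libinfinicore_infer.py):

    from ctypes import POINTER, c_float, c_uint

    # Hypothetical batch: request 0 has 5 prompt tokens, request 1 has 3.
    tokens = (c_uint * 8)(*[101, 57, 9, 42, 7, 88, 3, 12])  # flattened prompts
    req_lens = (c_uint * 2)(5, 3)
    req_pos = (c_uint * 2)(0, 0)
    kv_caches = (POINTER(KVCache) * 2)(cache0, cache1)  # cache0/cache1: assumed handles
    output = (c_uint * 2)()

    # Per-request sampling: request 0 greedy, request 1 nucleus sampling.
    temperature = (c_float * 2)(0.0, 1.0)
    topk = (c_uint * 2)(1, 50)
    topp = (c_float * 2)(1.0, 0.8)

    lib.inferBatch(model, tokens, 8, req_lens, 2, req_pos,
                   kv_caches, temperature, topk, topp, output)
    print(list(output))  # one sampled token per request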
scripts/infer_task.py  (+26 / -20)  view file @ dcd6693f

-import asyncio
+import janus
 
 
 class InferTask:
-    def __init__(self, id, tokenizer, request):
-        self.id_ = id
-        self.finished_reason = None
-        messages = request.get("messages", [])
-        if len(messages) == 0:
-            self.finished_reason = "invalid request"
-            self.tokens = []
-        else:
-            input_content = tokenizer.apply_chat_template(
-                conversation=messages,
-                add_generation_prompt=True,
-                tokenize=False,
-            )
-            self.tokens = tokenizer.encode(input_content)
-        self.request = request
-        self.output_queue = asyncio.Queue()
+    def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens):
+        self.id = id
+        self.finish_reason = None
+        self.tokens = tokens
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.topk = topk
+        self.topp = topp
+        self.end_tokens = end_tokens
+        self.output_queue = janus.Queue()
         self._kv_cache_pool_item = None
         self.pos = 0
+        print(f"[INFO] Create InferTask {self.id}")
 
     def bind_kvcache(self, kv_cache_pool_item, pos):
         self._kv_cache_pool_item = kv_cache_pool_item
         self.pos = pos
         self.tokens = self.tokens[pos:]
 
     def kvcache(self):
         return self._kv_cache_pool_item.kvcache
+
+    def output(self, out_token):
+        self._kv_cache_pool_item.update_tokens(self.tokens, self.pos)
+        self.pos += len(self.tokens)
+        if out_token == None or out_token in self.end_tokens:
+            self.finish_reason = "stop"
+        elif self.pos >= self.max_tokens:
+            self.finish_reason = "length"
+        else:
+            self.tokens = [out_token]
+        self.output_queue.sync_q.put(out_token)
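InferTask no longer tokenizes or parses the HTTP request itself: the server tokenizes once and passes plain token IDs plus per-request sampling settings, and output() both records decoded history into the bound cache item and feeds the janus queue that the HTTP handler reads. A rough lifecycle sketch (illustrative IDs; pool_item stands for a KVCachePoolItem from scripts/kvcache_pool.py; run inside a coroutine because janus.Queue needs a running event loop):

    import asyncio

    async def demo(pool_item):
        task = InferTask(
            id="req-0",
            tokens=[1, 15, 8, 42],  # illustrative prompt token IDs
            max_tokens=64,
            temperature=1.0,
            topk=50,
            topp=0.8,
            end_tokens=[2],  # assumed EOS id
        )
        task.bind_kvcache(pool_item, pos=0)
        # Worker-thread side, one round at a time:
        task.output(17)  # not an end token, so task.tokens becomes [17]
        task.output(2)   # end token -> finish_reason == "stop"
        # HTTP-handler side:
        print(task.finish_reason, await task.output_queue.async_q.get())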
scripts/jiuge.py
View file @
dcd6693f
from
ctypes
import
POINTER
,
c_int
,
c_uint
,
c_void_p
,
byref
from
typing
import
List
import
os
from
pathlib
import
Path
import
safetensors
import
sys
import
time
import
json
import
asyncio
from
libinfinicore_infer
import
(
from
libinfinicore_infer
import
(
JiugeMeta
,
JiugeMeta
,
JiugeWeights
,
JiugeWeights
,
...
@@ -19,6 +11,15 @@ from libinfinicore_infer import (
...
@@ -19,6 +11,15 @@ from libinfinicore_infer import (
drop_kv_cache
,
drop_kv_cache
,
infer_batch
,
infer_batch
,
)
)
from
infer_task
import
InferTask
from
ctypes
import
POINTER
,
c_float
,
c_int
,
c_uint
,
c_void_p
,
byref
import
os
from
pathlib
import
Path
import
safetensors
import
sys
import
time
import
json
import
torch
import
torch
import
transformers
import
transformers
...
@@ -286,6 +287,47 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -286,6 +287,47 @@ class JiugeWeightsImpl(JiugeWeights):
self
.
ffn_down
=
(
c_void_p
*
nlayer
)(
*
self
.
ffn_down_ptrs
)
self
.
ffn_down
=
(
c_void_p
*
nlayer
)(
*
self
.
ffn_down_ptrs
)
class
JiugeBatchedTask
:
def
__init__
(
self
,
tasks
:
List
[
InferTask
]):
self
.
tasks
=
tasks
self
.
nreq
=
len
(
tasks
)
# Precompute fields
token_lists
=
[
t
.
tokens
for
t
in
tasks
]
self
.
req_lens_list
=
[
len
(
toks
)
for
toks
in
token_lists
]
self
.
req_pos_list
=
[
t
.
pos
for
t
in
tasks
]
self
.
kv_cache_ptrs
=
[
t
.
kvcache
()
for
t
in
tasks
]
self
.
temperaturas_list
=
[
t
.
temperature
for
t
in
tasks
]
self
.
topks_list
=
[
t
.
topk
for
t
in
tasks
]
self
.
topps_list
=
[
t
.
topp
for
t
in
tasks
]
# Flatten token lists
flat_tokens
=
[
tok
for
toks
in
token_lists
for
tok
in
toks
]
self
.
ntok
=
len
(
flat_tokens
)
# Convert to ctypes arrays in one pass
self
.
tokens
=
(
c_uint
*
self
.
ntok
)(
*
flat_tokens
)
self
.
req_lens
=
(
c_uint
*
self
.
nreq
)(
*
self
.
req_lens_list
)
self
.
req_pos
=
(
c_uint
*
self
.
nreq
)(
*
self
.
req_pos_list
)
self
.
kv_caches
=
(
POINTER
(
KVCache
)
*
self
.
nreq
)(
*
self
.
kv_cache_ptrs
)
self
.
temperaturas
=
(
c_float
*
self
.
nreq
)(
*
self
.
temperaturas_list
)
self
.
topks
=
(
c_uint
*
self
.
nreq
)(
*
self
.
topks_list
)
self
.
topps
=
(
c_float
*
self
.
nreq
)(
*
self
.
topps_list
)
def
input_args
(
self
):
return
(
self
.
tokens
,
self
.
ntok
,
self
.
req_lens
,
self
.
nreq
,
self
.
req_pos
,
self
.
kv_caches
,
self
.
temperaturas
,
self
.
topks
,
self
.
topps
,
)
class
JiugeForCauslLM
:
class
JiugeForCauslLM
:
def
__init__
(
self
,
model_dir_path
,
device
=
DeviceType
.
DEVICE_TYPE_CPU
,
ndev
=
1
):
def
__init__
(
self
,
model_dir_path
,
device
=
DeviceType
.
DEVICE_TYPE_CPU
,
ndev
=
1
):
def
load_all_safetensors_from_dir
(
dir_path_
:
str
):
def
load_all_safetensors_from_dir
(
dir_path_
:
str
):
...
@@ -382,118 +424,17 @@ class JiugeForCauslLM:
...
@@ -382,118 +424,17 @@ class JiugeForCauslLM:
def
drop_kv_cache
(
self
,
kv_cache
):
def
drop_kv_cache
(
self
,
kv_cache
):
drop_kv_cache
(
self
.
model_instance
,
kv_cache
)
drop_kv_cache
(
self
.
model_instance
,
kv_cache
)
def
chat
(
self
,
request
,
kv_cache
):
def
batch_infer_one_round
(
self
,
tasks
:
List
[
InferTask
]):
messages
=
request
.
get
(
"messages"
,
[])
output
=
(
c_uint
*
len
(
tasks
))()
temperature
=
request
.
get
(
"temperature"
,
1.0
)
batch_inputs
=
JiugeBatchedTask
(
tasks
)
topk
=
request
.
get
(
"top_k"
,
1
)
infer_batch
(
topp
=
request
.
get
(
"top_p"
,
1.0
)
self
.
model_instance
,
max_tokens
=
request
.
get
(
"max_tokens"
,
self
.
meta
.
dctx
)
*
(
batch_inputs
.
input_args
()),
input_content
=
self
.
tokenizer
.
apply_chat_template
(
output
,
conversation
=
messages
,
add_generation_prompt
=
True
,
tokenize
=
False
,
)
tokens
=
self
.
tokenizer
.
encode
(
input_content
)
ntok
=
len
(
tokens
)
nreq
=
1
output_content
=
""
tokens
=
(
c_uint
*
ntok
)(
*
tokens
)
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
req_pos
=
(
c_uint
*
nreq
)(
*
[
0
])
kv_caches
=
(
POINTER
(
KVCache
)
*
nreq
)(
*
[
kv_cache
])
ans
=
(
c_uint
*
nreq
)()
steps
=
0
for
step_i
in
range
(
max_tokens
):
infer_batch
(
self
.
model_instance
,
tokens
,
ntok
,
req_lens
,
nreq
,
req_pos
,
kv_caches
,
ans
,
temperature
,
topk
,
topp
,
)
steps
+=
1
output_tokens
=
list
(
ans
)
output_str
=
(
self
.
tokenizer
.
_tokenizer
.
id_to_token
(
output_tokens
[
0
])
.
replace
(
"▁"
,
" "
)
.
replace
(
"<0x0A>"
,
"
\n
"
)
)
output_content
+=
output_str
if
output_tokens
[
0
]
in
self
.
eos_token_id
:
break
req_pos
[
0
]
=
req_pos
[
0
]
+
ntok
ntok
=
1
tokens
=
(
c_uint
*
ntok
)(
*
output_tokens
)
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
return
output_content
async
def
chat_stream_async
(
self
,
request
,
kv_cache
):
messages
=
request
.
get
(
"messages"
,
[])
temperature
=
request
.
get
(
"temperature"
,
1.0
)
topk
=
request
.
get
(
"top_k"
,
1
)
topp
=
request
.
get
(
"top_p"
,
1.0
)
max_tokens
=
request
.
get
(
"max_tokens"
,
512
)
input_content
=
self
.
tokenizer
.
apply_chat_template
(
conversation
=
messages
,
add_generation_prompt
=
True
,
tokenize
=
False
,
)
)
return
list
(
output
)
tokens
=
self
.
tokenizer
.
encode
(
input_content
)
def
generate
(
self
,
input_content
,
max_steps
,
topp_
=
1.0
,
topk_
=
1
,
temperature_
=
1.0
):
ntok
=
len
(
tokens
)
nreq
=
1
tokens
=
(
c_uint
*
ntok
)(
*
tokens
)
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
req_pos
=
(
c_uint
*
nreq
)(
*
[
0
])
kv_caches
=
(
POINTER
(
KVCache
)
*
nreq
)(
*
[
kv_cache
])
ans
=
(
c_uint
*
nreq
)()
for
step_i
in
range
(
max_tokens
):
infer_batch
(
self
.
model_instance
,
tokens
,
ntok
,
req_lens
,
nreq
,
req_pos
,
kv_caches
,
ans
,
temperature
,
topk
,
topp
,
)
output_tokens
=
list
(
ans
)
output_str
=
(
self
.
tokenizer
.
_tokenizer
.
id_to_token
(
output_tokens
[
0
])
.
replace
(
"▁"
,
" "
)
.
replace
(
"<0x0A>"
,
"
\n
"
)
)
yield
output_str
# Yield each token as it's produced
await
asyncio
.
sleep
(
0
)
# Let event loop breathe
if
output_tokens
[
0
]
in
self
.
eos_token_id
:
break
req_pos
[
0
]
+=
ntok
ntok
=
1
tokens
=
(
c_uint
*
ntok
)(
*
output_tokens
)
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
def
generate
(
self
,
input_content
,
max_steps
,
topp
=
1.0
,
topk
=
1
,
temperature
=
1.0
):
kv_cache
=
create_kv_cache
(
self
.
model_instance
)
kv_cache
=
create_kv_cache
(
self
.
model_instance
)
input_content
=
self
.
tokenizer
.
apply_chat_template
(
input_content
=
self
.
tokenizer
.
apply_chat_template
(
conversation
=
[{
"role"
:
"user"
,
"content"
:
input_content
}],
conversation
=
[{
"role"
:
"user"
,
"content"
:
input_content
}],
...
@@ -510,6 +451,9 @@ class JiugeForCauslLM:
...
@@ -510,6 +451,9 @@ class JiugeForCauslLM:
req_pos
=
(
c_uint
*
nreq
)(
*
[
0
])
req_pos
=
(
c_uint
*
nreq
)(
*
[
0
])
kv_caches
=
(
POINTER
(
KVCache
)
*
nreq
)(
*
[
kv_cache
])
kv_caches
=
(
POINTER
(
KVCache
)
*
nreq
)(
*
[
kv_cache
])
ans
=
(
c_uint
*
nreq
)()
ans
=
(
c_uint
*
nreq
)()
temperature
=
(
c_float
*
nreq
)(
*
[
temperature_
])
topk
=
(
c_uint
*
nreq
)(
*
[
topk_
])
topp
=
(
c_float
*
nreq
)(
*
[
topp_
])
steps
=
0
steps
=
0
total_time
=
0
total_time
=
0
...
@@ -524,10 +468,10 @@ class JiugeForCauslLM:
...
@@ -524,10 +468,10 @@ class JiugeForCauslLM:
nreq
,
nreq
,
req_pos
,
req_pos
,
kv_caches
,
kv_caches
,
ans
,
temperature
,
temperature
,
topk
,
topk
,
topp
,
topp
,
ans
,
)
)
steps
+=
1
steps
+=
1
output_tokens
=
list
(
ans
)
output_tokens
=
list
(
ans
)
...
@@ -545,7 +489,7 @@ class JiugeForCauslLM:
...
@@ -545,7 +489,7 @@ class JiugeForCauslLM:
ntok
=
1
ntok
=
1
tokens
=
(
c_uint
*
ntok
)(
*
output_tokens
)
tokens
=
(
c_uint
*
ntok
)(
*
output_tokens
)
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
if
step_i
>
0
:
if
step_i
>
0
:
total_time
+=
end_time
-
start_time
total_time
+=
end_time
-
start_time
...
@@ -555,7 +499,7 @@ class JiugeForCauslLM:
...
@@ -555,7 +499,7 @@ class JiugeForCauslLM:
for
kv_cache
in
kv_caches
:
for
kv_cache
in
kv_caches
:
drop_kv_cache
(
self
.
model_instance
,
kv_cache
)
drop_kv_cache
(
self
.
model_instance
,
kv_cache
)
return
output_content
,
avg_time
return
output_content
,
avg_time
def
destroy_model_instance
(
self
):
def
destroy_model_instance
(
self
):
destroy_jiuge_model
(
self
.
model_instance
)
destroy_jiuge_model
(
self
.
model_instance
)
print
(
"Model destroyed"
)
print
(
"Model destroyed"
)
...
...
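JiugeBatchedTask gathers per-task prompt tokens, positions, KV-cache pointers, and sampling settings into ctypes arrays so a heterogeneous batch goes through one inferBatch call. A standalone sketch of the flattening it performs (pure Python, no FFI; values are illustrative):

    # Two in-flight tasks: one still in prefill (4 prompt tokens at pos 0),
    # one already decoding (1 token at pos 9).
    tasks = [
        {"tokens": [5, 9, 2, 7], "pos": 0},
        {"tokens": [31], "pos": 9},
    ]

    flat_tokens = [tok for t in tasks for tok in t["tokens"]]  # [5, 9, 2, 7, 31]
    req_lens = [len(t["tokens"]) for t in tasks]               # [4, 1]
    req_pos = [t["pos"] for t in tasks]                        # [0, 9]

    # inferBatch sees one flat token stream plus per-request lengths/positions,
    # and returns exactly one new token per request.
    assert sum(req_lens) == len(flat_tokens)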
scripts/kvcache_pool.py  (+48 / -18)  view file @ dcd6693f

...
@@ -10,66 +10,96 @@ class KVCachePoolItem:
     def drop(self, model):
         model.drop_kv_cache(self.kvcache)
 
+    def update_tokens(self, tokens, pos):
+        end = pos + len(tokens)
+        max_len = len(self.tokens)
+        # If overflow, truncate tokens to fit
+        if end > max_len:
+            tokens = tokens[: max_len - pos]
+            end = max_len
+        self.tokens[pos:end] = tokens
+
+
+import threading
+
 
 class KVCachePool:
     def __init__(self, model, max_caches: int = 32):
         self.max_caches = max_caches
         self.model = model
-        self._available: List[KVCachePoolItem] = [KVCachePoolItem(self.model)]
-        self.num_caches = 1
-        self._lock = asyncio.Lock()
-        self._not_empty = asyncio.Condition(self._lock)
+        self._available: List[KVCachePoolItem] = []
+        self.num_caches = len(self._available)
+        self._lock = threading.Lock()
+        self._not_empty = threading.Condition(self._lock)
         self._shutdown = False
 
-    async def acquire(self, infer_task):
-        async with self._not_empty:
+    def acquire_sync(self, infer_task):
+        with self._not_empty:
             while True:
                 if self._shutdown:
                     raise RuntimeError("KVCachePool is shutting down; cannot acquire new cache.")
                 if len(self._available) == 0:
                     if self.num_caches < self.max_caches:
                         self.num_caches += 1
+                        print(f"[INFO] Task {infer_task.id} created new KVCachePoolItem")
                         return infer_task.bind_kvcache(KVCachePoolItem(self.model), 0)
                     else:
-                        await self._not_empty.wait()
+                        self._not_empty.wait()
                 else:
                     max_match, max_match_index = self.find_most_matching_cache(
                         infer_task.tokens
                     )
                     kvcache = self._available.pop(max_match_index)
+                    print(
+                        f"[INFO] Task {infer_task.id} reused KVCachePoolItem {max_match_index} with {max_match} matches"
+                    )
                     return infer_task.bind_kvcache(kvcache, max_match)
 
-    async def release(self, infer_task):
-        async with self._not_empty:
+    def release_sync(self, infer_task):
+        with self._not_empty:
+            print(f"[INFO] Task {infer_task.id} returned KVCachePoolItem to pool")
             self._available.append(infer_task._kv_cache_pool_item)
+            infer_task._kv_cache_pool_item = None
             self._not_empty.notify()
 
+    async def acquire(self, infer_task):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.acquire_sync, infer_task)
+
+    async def release(self, infer_task):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.release_sync, infer_task)
+
     def find_most_matching_cache(self, tokens: List[int]):
         max_match = 0
         max_match_index = 0
 
         def first_different_index(a_, b_):
             for i_, (x_, y_) in enumerate(zip(a_, b_)):
                 if x_ != y_:
                     return i_
             return min(len(a_), len(b_))
 
         for i, kvcache in enumerate(self._available):
             common_elements = first_different_index(tokens, kvcache.tokens)
+            # print(f"{tokens}")
+            # print(f"{kvcache.tokens[:len(tokens)]}")
             if common_elements > max_match:
                 max_match = common_elements
                 max_match_index = i
+        # max match should always be less than the input token length
         return (min(max_match, len(tokens) - 1), max_match_index)
 
-    async def finalize(self):
-        async with self._not_empty:
+    def finalize(self):
+        with self._not_empty:
             self._shutdown = True
             while len(self._available) < self.num_caches:
-                await self._not_empty.wait()
+                self._not_empty.wait()
+            # All caches are now available
             for kvcache in self._available:
                 if kvcache is not None:
                     kvcache.drop(self.model)
...
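The pool reuses the cache whose recorded token history shares the longest prefix with the incoming prompt, so a follow-up turn in the same conversation skips prefill for the shared prefix. A standalone sketch of that matching rule (same logic as find_most_matching_cache; data is illustrative):

    def longest_common_prefix(a, b):
        # Index of the first mismatch, i.e. length of the shared prefix.
        for i, (x, y) in enumerate(zip(a, b)):
            if x != y:
                return i
        return min(len(a), len(b))

    # Histories already recorded in pooled caches (illustrative token IDs).
    cached_histories = [[1, 5, 9, 2], [1, 5, 7, 7, 3]]
    prompt = [1, 5, 7, 7, 4]

    matches = [longest_common_prefix(prompt, h) for h in cached_histories]  # [2, 4]
    best = max(range(len(matches)), key=matches.__getitem__)                # cache 1

    # Clamp like the pool does: at least one token must still be prefilled
    # so the model produces a fresh output for this request.
    pos = min(matches[best], len(prompt) - 1)  # start decoding from here (4)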
scripts/launch_server.py  (+141 / -54)  view file @ dcd6693f

-import asyncio
 from jiuge import JiugeForCauslLM
 from libinfinicore_infer import DeviceType
 from infer_task import InferTask
 from kvcache_pool import KVCachePool
+import queue
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse, JSONResponse
+import anyio
+import contextlib
 import uvicorn
 import time
 import uuid
 import sys
-import signal
 import json
+import threading
+import janus
 
 if len(sys.argv) < 3:
     print(
...
@@ -40,26 +41,6 @@ else:
     sys.exit(1)
 
 ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
-MODEL = JiugeForCauslLM(model_path, device_type, ndev)
-App = FastAPI()
-
-
-@App.on_event("startup")
-async def setup():
-    App.state.kv_cache_pool = KVCachePool(MODEL, 1)
-
-
-async def handle_shutdown():
-    await App.state.kv_cache_pool.finalize()
-    MODEL.destroy_model_instance()
-    sys.exit(0)
-
-
-def signal_handler(sig, frame):
-    print(f"Received signal {sig}, cleaning up...")
-    asyncio.create_task(handle_shutdown())
-
-
-signal.signal(signal.SIGINT, signal_handler)  # Handle Ctrl+C
-signal.signal(signal.SIGTERM, signal_handler)  # Handle docker stop / system shutdown
 
 
 def chunk_json(id_, content=None, role=None, finish_reason=None):
     delta = {}
...
@@ -84,50 +65,156 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
     }
 
+
+MAX_BATCH = 3
+print(
+    f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs."
+)
+
+
+@contextlib.asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Startup
+    app.state.model = JiugeForCauslLM(model_path, device_type, ndev)
+    app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH)
+    app.state.request_queue = janus.Queue()
+    worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True)
+    worker_thread.start()
+    try:
+        yield  # The app runs here
+    finally:
+        # Shutdown
+        app.state.request_queue.sync_q.put(None)
+        worker_thread.join()
+        app.state.request_queue.shutdown()
+        app.state.kv_cache_pool.finalize()
+        app.state.model.destroy_model_instance()
+
+
+App = FastAPI(lifespan=lifespan)
+
+
+# App loop: take requests from the queue, do inference, and put unfinished requests back into the queue.
+def worker_loop(app):
+    while True:
+        try:
+            task = app.state.request_queue.sync_q.get(timeout=0.01)
+        except queue.Empty:
+            continue
+        if task is None:
+            return
+        batch = [task]
+        while len(batch) < MAX_BATCH:
+            try:
+                req = app.state.request_queue.sync_q.get_nowait()
+                if req is not None:
+                    batch.append(req)
+            except queue.Empty:
+                break
+        output_tokens = app.state.model.batch_infer_one_round(batch)
+        for task, token in zip(batch, output_tokens):
+            task.output(token)
+            if task.finish_reason is None:
+                app.state.request_queue.sync_q.put(task)
+            else:
+                print(f"[INFO] Task {task.id} finished infer.")
+                app.state.kv_cache_pool.release_sync(task)
+
+
+def build_task(id_, request_data, request: Request):
+    messages = request_data.get("messages", [])
+    input_content = request.app.state.model.tokenizer.apply_chat_template(
+        conversation=messages,
+        add_generation_prompt=True,
+        tokenize=False,
+    )
+    tokens = request.app.state.model.tokenizer.encode(input_content)
+    return InferTask(
+        id_,
+        tokens,
+        request_data.get("max_tokens", request.app.state.model.max_context_len()),
+        request_data.get("temperature", 1.0),
+        request_data.get("top_k", 1),
+        request_data.get("top_p", 1.0),
+        request.app.state.model.eos_token_id,
+    )
+
 
 async def chat_stream(id_, request_data, request: Request):
     try:
-        infer_task = InferTask(id_, MODEL.tokenizer, request_data)
-        await App.state.kv_cache_pool.acquire(infer_task)
+        infer_task = build_task(id_, request_data, request)
+        await request.app.state.kv_cache_pool.acquire(infer_task)
+        # Initial empty content
         chunk = json.dumps(
             chunk_json(id_, content="", role="assistant"),
-            ensure_ascii=False
+            ensure_ascii=False,
         )
         yield f"{chunk}\n\n"
-        async for token in MODEL.chat_stream_async(
-            infer_task.request,
-            infer_task.kvcache(),
-        ):
+        request.app.state.request_queue.sync_q.put(infer_task)
+        while True:
             if await request.is_disconnected():
                 print("Client disconnected. Aborting stream.")
                 break
-            chunk = json.dumps(
-                chunk_json(id_, content=token),
-                ensure_ascii=False,
-            )
+            if (
+                infer_task.finish_reason is not None
+                and infer_task.output_queue.async_q.empty()
+            ):
+                chunk = json.dumps(
+                    chunk_json(id_, finish_reason=infer_task.finish_reason),
+                    ensure_ascii=False,
+                )
+                yield f"{chunk}\n\n"
+                break
+            token = await infer_task.output_queue.async_q.get()
+            content = (
+                request.app.state.model.tokenizer._tokenizer.id_to_token(token)
+                .replace("▁", " ")
+                .replace("<0x0A>", "\n")
+            )
+            chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False)
             yield f"{chunk}\n\n"
-    finally:
-        await App.state.kv_cache_pool.release(infer_task)
-        chunk = json.dumps(
-            chunk_json(id_, finish_reason="stop"),
-            ensure_ascii=False,
-        )
-        yield f"{chunk}\n\n"
+    except Exception as e:
+        print(f"[Error] ID : {id_} Exception: {e}")
 
 
-async def chat(id_, request_data):
-    infer_task = InferTask(id_, MODEL.tokenizer, request_data)
-    await App.state.kv_cache_pool.acquire(infer_task)
-    output_text = MODEL.chat(
-        infer_task.request,
-        infer_task.kvcache(),
-    )
-    response = chunk_json(
-        id_, content=output_text.strip(), role="assistant", finish_reason="stop"
-    )
-    await App.state.kv_cache_pool.release(infer_task)
-    return JSONResponse(response)
+async def chat(id_, request_data, request: Request):
+    try:
+        infer_task = build_task(id_, request_data, request)
+        await request.app.state.kv_cache_pool.acquire(infer_task)
+        request.app.state.request_queue.sync_q.put(infer_task)
+        output = []
+        while True:
+            if (
+                infer_task.finish_reason is not None
+                and infer_task.output_queue.async_q.empty()
+            ):
+                break
+            token = await infer_task.output_queue.async_q.get()
+            content = (
+                request.app.state.model.tokenizer._tokenizer.id_to_token(token)
+                .replace("▁", " ")
+                .replace("<0x0A>", "\n")
+            )
+            output.append(content)
+        output_text = "".join(output).strip()
+        response = chunk_json(
+            id_,
+            content=output_text,
+            role="assistant",
+            finish_reason=infer_task.finish_reason or "stop",
+        )
+        return response
+    except Exception as e:
+        print(f"[Error] ID: {id_} Exception: {e}")
+        return JSONResponse(content={"error": str(e)}, status_code=500)
 
 
 @App.post("/chat/completions")
...
@@ -144,7 +231,7 @@ async def chat_completions(request: Request):
             chat_stream(id_, data, request), media_type="text/event-stream"
         )
     else:
-        return chat(id_, data)
+        return JSONResponse(chat(id_, data))
 
 
 if __name__ == "__main__":
...
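This file is the heart of the change: instead of one long generate loop per connection, a single worker thread repeatedly drains up to MAX_BATCH tasks from a janus queue, runs one decoding round for all of them, and requeues the unfinished ones, so new requests join the batch between rounds. A minimal standalone sketch of that scheduling loop (plain queue.Queue and a fake model step; illustrative only):

    import queue

    def fake_step(batch):
        # Stand-in for JiugeForCauslLM.batch_infer_one_round: one token per task.
        return [100 + t["id"] for t in batch]

    q = queue.Queue()
    tasks = [{"id": 0, "remaining": 2, "out": []},
             {"id": 1, "remaining": 1, "out": []}]
    for t in tasks:
        q.put(t)

    MAX_BATCH = 3
    while not q.empty():  # the server loops forever; we stop when drained
        batch = [q.get()]
        while len(batch) < MAX_BATCH:  # opportunistically fill the batch
            try:
                batch.append(q.get_nowait())
            except queue.Empty:
                break
        for t, tok in zip(batch, fake_step(batch)):
            t["out"].append(tok)
            t["remaining"] -= 1
            if t["remaining"] > 0:  # unfinished tasks rejoin the next round
                q.put(t)

    print(tasks[0]["out"], tasks[1]["out"])  # [100, 100] and [101]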
scripts/libinfinicore_infer.py  (+4 / -3)  view file @ dcd6693f

...
@@ -94,6 +94,7 @@ def __open_library__():
         POINTER(c_int),  # int const *dev_ids
     ]
     lib.destroyJiugeModel.argtypes = [POINTER(JiugeModel)]
+    lib.createKVCache.argtypes = [POINTER(JiugeModel)]
     lib.createKVCache.restype = POINTER(KVCache)
     lib.dropKVCache.argtypes = [POINTER(JiugeModel), POINTER(KVCache)]
     lib.inferBatch.restype = None
...
@@ -105,10 +106,10 @@ def __open_library__():
         c_uint,  # unsigned int nreq
         POINTER(c_uint),  # unsigned int const *req_pos
         POINTER(POINTER(KVCache)),  # struct KVCache **kv_caches
+        POINTER(c_float),  # float temperature
+        POINTER(c_uint),  # unsigned int topk
+        POINTER(c_float),  # float topp
         POINTER(c_uint),  # unsigned int *output
-        c_float,  # float temperature
-        c_uint,  # unsigned int topk
-        c_float,  # float topp
     ]
     return lib
...
scripts/test_server.py  (+5 / -1)  view file @ dcd6693f

...
@@ -5,13 +5,14 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 API_URL = "http://localhost:8000/chat/completions"
 MODEL = "FM9G-7B"
-PROMPT = ["给我讲个故事", "山东最高的山是?"]
+PROMPT = ["山东最高的山是?", "给我讲个故事"]
 CONCURRENCY = 10  # number of concurrent users
 
 
 def single_run(user_id):
     payload = {
         "model": MODEL,
         "messages": [{"role": "user", "content": PROMPT[user_id % len(PROMPT)]}],
+        "max_tokens": 512,
         "stream": True
     }
     headers = {'Content-Type': 'application/json', 'Accept': 'application/json'}
...
@@ -86,6 +87,9 @@ def main():
         if r['stream_time'] < best_stream:
             best_stream = r['stream_time']
             best = r
 
+    # Sort results by user ID
+    results.sort(key=lambda x: x["user"])
+
     with open("responses.txt", "w", encoding="utf-8") as fw:
         for r in results:
...
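The test drives the server with concurrent streaming requests. A minimal single-request version of the same call, assuming the server from scripts/launch_server.py is on localhost:8000, that chunk_json() follows the OpenAI-style layout with choices[0].delta (its delta/finish_reason parameters suggest this), and that each yielded line is a bare JSON chunk as in the server's f"{chunk}\n\n" yields:

    import json
    import requests

    payload = {
        "model": "FM9G-7B",
        "messages": [{"role": "user", "content": "给我讲个故事"}],
        "max_tokens": 512,
        "stream": True,
    }

    # Read the stream line by line; skip keep-alive blanks, parse each JSON chunk.
    with requests.post("http://localhost:8000/chat/completions",
                       json=payload, stream=True, timeout=60) as resp:
        for line in resp.iter_lines():
            if not line:
                continue
            chunk = json.loads(line)
            delta = chunk["choices"][0]["delta"]  # assumed chunk layout
            print(delta.get("content", ""), end="", flush=True)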
src/models/jiuge/jiuge.cpp  (+11 / -9)  view file @ dcd6693f

...
@@ -115,8 +115,8 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
                       const uint32_t *tokens, uint32_t ntok,
                       const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                       struct KVCache **kv_caches,
-                      uint32_t *ans,
-                      float temperature, uint32_t topk, float topp) {
+                      const float *temperature, const uint32_t *topk, const float *topp,
+                      uint32_t *output) {
     auto nlayer = meta.nlayer;
     auto nkvh = meta.nkvh / ndev;
     auto nh = meta.nh / ndev;
...
@@ -457,8 +457,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
             RUN_INFINI(infiniopRandomSample(
                 desc_sample, workspace, workspace_size,
                 result_buf->data(req),
                 prob_buf->data(req * dvoc),
-                random_val, topp,
-                topk, temperature, stream));
+                random_val, topp[req],
+                topk[req], temperature[req],
+                stream));
             // result_buf->debug();
             token_offset += seq_len;
         }
...
@@ -466,7 +468,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
         RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(),
                                   sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H));
         for (uint32_t req = 0; req < nreq; req++) {
-            ans[req] = result_cpu[req];
+            output[req] = result_cpu[req];
         }
     }
...
@@ -500,15 +502,15 @@ inferBatch(struct JiugeModel *model,
            const uint32_t *tokens, uint32_t ntok,
            const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
            struct KVCache **kv_caches,
-           uint32_t *ans,
-           float temperature, uint32_t topk, float topp) {
+           const float *temperature, const uint32_t *topk, const float *topp,
+           uint32_t *output) {
     model->req.tokens = tokens;
     model->req.ntok = ntok;
     model->req.req_lens = req_lens;
     model->req.nreq = nreq;
     model->req.req_pos = req_pos;
     model->req.kv_caches = kv_caches;
-    model->req.ans = ans;
+    model->req.output = output;
     model->req.temperature = temperature;
     model->req.topk = topk;
     model->req.topp = topp;
...
@@ -547,7 +549,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
             break;
         }
-        inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos, req.kv_caches, req.ans, req.temperature, req.topk, req.topp);
+        inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos, req.kv_caches, req.temperature, req.topk, req.topp, req.output);
         state.proceed = false;
         lock.unlock();
...
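With the scalar parameters replaced by arrays, each request in the batch is sampled with its own temperature/top-k/top-p inside the per-request loop. A reference sketch in Python of what one such sampling step computes (not the infiniop kernel itself; plain top-k/top-p softmax sampling over illustrative logits):

    import math
    import random

    def sample_one(logits, temperature, topk, topp, rng=random.random):
        # Greedy shortcut, matching the documented convention (temperature 0. / topk 1).
        if temperature == 0.0 or topk == 1:
            return max(range(len(logits)), key=logits.__getitem__)
        # Keep the top-k logits, softmax with temperature.
        idx = sorted(range(len(logits)), key=logits.__getitem__, reverse=True)[:topk]
        exps = [math.exp(logits[i] / temperature) for i in idx]
        total = sum(exps)
        probs = [e / total for e in exps]
        # Nucleus (top-p) cut, then renormalize and draw.
        kept, acc = [], 0.0
        for i, p in zip(idx, probs):
            kept.append((i, p))
            acc += p
            if acc >= topp:
                break
        r = rng() * acc
        for i, p in kept:
            r -= p
            if r <= 0:
                return i
        return kept[-1][0]

    # Each request gets its own settings, as in the new inferBatch signature.
    batch_logits = [[0.1, 2.0, 0.3], [1.5, 1.4, 0.2]]
    params = [(0.0, 1, 1.0), (0.8, 2, 0.9)]  # (temperature, topk, topp) per request
    print([sample_one(l, *p) for l, p in zip(batch_logits, params)])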
src/models/jiuge/jiuge_impl.hpp  (+4 / -4)  view file @ dcd6693f

...
@@ -45,10 +45,10 @@ struct InferRequest {
     uint32_t nreq;
     const uint32_t *req_pos;
     struct KVCache **kv_caches;
-    uint32_t *ans;
-    float temperature;
-    uint32_t topk;
-    float topp;
+    const float *temperature;
+    const uint32_t *topk;
+    const float *topp;
+    uint32_t *output;
 };
 
 struct JiugeModel {
...