jerrrrry / infinilm · Commits

Commit dcd6693f, authored Jun 24, 2025 by Pan Zezhong
Parent: 21f83e91

Support continuous batching
Showing 9 changed files with 308 additions and 250 deletions:

include/infinicore_infer/models/jiuge.h   +3    -19
scripts/infer_task.py                     +26   -20
scripts/jiuge.py                          +66   -122
scripts/kvcache_pool.py                   +48   -18
scripts/launch_server.py                  +141  -54
scripts/libinfinicore_infer.py            +4    -3
scripts/test_server.py                    +5    -1
src/models/jiuge/jiuge.cpp                +11   -9
src/models/jiuge/jiuge_impl.hpp           +4    -4
include/infinicore_infer/models/jiuge.h

@@ -75,22 +75,6 @@ __C __export void
 dropKVCache(const struct JiugeModel *,
             struct KVCache *);

-/// @brief Text generation
-/// @param tokens input tokens
-/// @param ntok number of input tokens
-/// @param req_pos starting position of each request
-/// @param output address for the output tokens
-/// @param max_step maximum number of output tokens
-/// @param temperature sampling temperature (0. means greedy sampling)
-/// @param topk sampling top-k (1 means greedy sampling)
-/// @param topp sampling top-p
-__C __export void
-generate(struct JiugeModel *, struct KVCache *,
-         const uint32_t *tokens, uint32_t ntok, uint32_t req_pos,
-         uint32_t *output, uint32_t max_step,
-         float temperature, uint32_t topk, float topp);
-
 /// @brief One round of batched inference
 /// @param tokens address of the input tokens
 /// @param ntok number of input tokens
@@ -98,16 +82,16 @@ generate(struct JiugeModel *,
 /// @param req_lens number of tokens in each request
 /// @param req_pos starting position of each request
 /// @param kv_caches KV cache of each request
-/// @param ans output token array, one output per request, length at least nreq
 /// @param temperature sampling temperature (0. means greedy sampling)
 /// @param topk sampling top-k (1 means greedy sampling)
 /// @param topp sampling top-p
+/// @param output output token array, one output per request, length at least nreq
 __C __export void
 inferBatch(struct JiugeModel *,
            const uint32_t *tokens, uint32_t ntok,
            const uint32_t *req_lens, uint32_t nreq,
            const uint32_t *req_pos, struct KVCache **kv_caches,
-           uint32_t *output,
-           float temperature, uint32_t topk, float topp);
+           const float *temperature, const uint32_t *topk, const float *topp,
+           uint32_t *output);

 #endif
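
The per-request sampling arrays are the heart of this change: every request in a batch now carries its own temperature/top-k/top-p, and the output slot moves to the last parameter. A minimal sketch of the new calling convention from Python, assuming the infer_batch and KVCache bindings exported by scripts/libinfinicore_infer.py (the model handle and the two cache handles are placeholders):

from ctypes import POINTER, c_float, c_uint
from libinfinicore_infer import KVCache, infer_batch

def infer_two_requests(model, kv_a, kv_b, toks_a, toks_b):
    # Two requests packed into one flat token buffer.
    flat = toks_a + toks_b
    nreq = 2
    tokens = (c_uint * len(flat))(*flat)
    req_lens = (c_uint * nreq)(len(toks_a), len(toks_b))
    req_pos = (c_uint * nreq)(0, 0)            # both start from an empty cache
    kv_caches = (POINTER(KVCache) * nreq)(kv_a, kv_b)
    # Per-request sampling: request 0 greedy, request 1 sampled.
    temperature = (c_float * nreq)(0.0, 0.8)
    topk = (c_uint * nreq)(1, 40)
    topp = (c_float * nreq)(1.0, 0.9)
    output = (c_uint * nreq)()                 # one next token per request
    infer_batch(model, tokens, len(flat), req_lens, nreq, req_pos,
                kv_caches, temperature, topk, topp, output)
    return list(output)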
scripts/infer_task.py

-import asyncio
+import janus

 class InferTask:
-    def __init__(self, id, tokenizer, request):
-        self.id_ = id
-        self.finished_reason = None
-        messages = request.get("messages", [])
-        if len(messages) == 0:
-            self.finished_reason = "invalid request"
-            self.tokens = []
-        else:
-            input_content = tokenizer.apply_chat_template(
-                conversation=messages,
-                add_generation_prompt=True,
-                tokenize=False,
-            )
-            self.tokens = tokenizer.encode(input_content)
-        self.request = request
-        self.output_queue = asyncio.Queue()
+    def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens):
+        self.id = id
+        self.finish_reason = None
+        self.tokens = tokens
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.topk = topk
+        self.topp = topp
+        self.end_tokens = end_tokens
+        self.output_queue = janus.Queue()
+        self._kv_cache_pool_item = None
+        self.pos = 0
+        print(f"[INFO] Create InferTask {self.id}")

     def bind_kvcache(self, kv_cache_pool_item, pos):
         self._kv_cache_pool_item = kv_cache_pool_item
@@ -30,3 +24,15 @@ class InferTask:
     def kvcache(self):
         return self._kv_cache_pool_item.kvcache

+    def output(self, out_token):
+        self._kv_cache_pool_item.update_tokens(self.tokens, self.pos)
+        self.pos += len(self.tokens)
+        if out_token is None or out_token in self.end_tokens:
+            self.finish_reason = "stop"
+        elif self.pos >= self.max_tokens:
+            self.finish_reason = "length"
+        else:
+            self.tokens = [out_token]
+        self.output_queue.sync_q.put(out_token)
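
The reworked InferTask is a small state machine: output() flushes the tokens just consumed into the bound cache item, advances pos, and either finishes the task ("stop"/"length") or schedules the sampled token as the next single-token input. A minimal sketch of that lifecycle, assuming a stand-in pool item (the real KVCachePoolItem lives in scripts/kvcache_pool.py and needs a model); the janus queue created in __init__ requires a running event loop, hence the asyncio.run wrapper:

import asyncio
from infer_task import InferTask

class FakePoolItem:
    """Stand-in for KVCachePoolItem: just records the tokens a task writes."""
    def __init__(self, capacity=16):
        self.tokens = [0] * capacity
        self.kvcache = None
    def update_tokens(self, tokens, pos):
        self.tokens[pos:pos + len(tokens)] = tokens

async def main():
    task = InferTask(0, [5, 6, 7], max_tokens=8,
                     temperature=0.0, topk=1, topp=1.0, end_tokens=[2])
    task.bind_kvcache(FakePoolItem(), 0)
    task.output(11)   # prompt flushed to the cache; next round decodes [11]
    assert task.finish_reason is None and task.tokens == [11]
    task.output(2)    # an end token finishes the task
    assert task.finish_reason == "stop"

asyncio.run(main())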
scripts/jiuge.py

-from ctypes import POINTER, c_int, c_uint, c_void_p, byref
-import os
-from pathlib import Path
-import safetensors
-import sys
-import time
-import json
 import asyncio
+from typing import List
 from libinfinicore_infer import (
     JiugeMeta,
     JiugeWeights,
@@ -19,6 +11,15 @@ from libinfinicore_infer import (
     drop_kv_cache,
     infer_batch,
 )
+from infer_task import InferTask
+from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
+import os
+from pathlib import Path
+import safetensors
+import sys
+import time
+import json
 import torch
 import transformers
@@ -286,6 +287,47 @@ class JiugeWeightsImpl(JiugeWeights):
         self.ffn_down = (c_void_p * nlayer)(*self.ffn_down_ptrs)

+class JiugeBatchedTask:
+    def __init__(self, tasks: List[InferTask]):
+        self.tasks = tasks
+        self.nreq = len(tasks)
+
+        # Precompute fields
+        token_lists = [t.tokens for t in tasks]
+        self.req_lens_list = [len(toks) for toks in token_lists]
+        self.req_pos_list = [t.pos for t in tasks]
+        self.kv_cache_ptrs = [t.kvcache() for t in tasks]
+        self.temperaturas_list = [t.temperature for t in tasks]
+        self.topks_list = [t.topk for t in tasks]
+        self.topps_list = [t.topp for t in tasks]
+
+        # Flatten token lists
+        flat_tokens = [tok for toks in token_lists for tok in toks]
+        self.ntok = len(flat_tokens)
+
+        # Convert to ctypes arrays in one pass
+        self.tokens = (c_uint * self.ntok)(*flat_tokens)
+        self.req_lens = (c_uint * self.nreq)(*self.req_lens_list)
+        self.req_pos = (c_uint * self.nreq)(*self.req_pos_list)
+        self.kv_caches = (POINTER(KVCache) * self.nreq)(*self.kv_cache_ptrs)
+        self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list)
+        self.topks = (c_uint * self.nreq)(*self.topks_list)
+        self.topps = (c_float * self.nreq)(*self.topps_list)
+
+    def input_args(self):
+        return (
+            self.tokens,
+            self.ntok,
+            self.req_lens,
+            self.nreq,
+            self.req_pos,
+            self.kv_caches,
+            self.temperaturas,
+            self.topks,
+            self.topps,
+        )
+
 class JiugeForCauslLM:
     def __init__(self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1):
         def load_all_safetensors_from_dir(dir_path_: str):
@@ -382,118 +424,17 @@ class JiugeForCauslLM:
     def drop_kv_cache(self, kv_cache):
         drop_kv_cache(self.model_instance, kv_cache)

-    def chat(self, request, kv_cache):
-        messages = request.get("messages", [])
-        temperature = request.get("temperature", 1.0)
-        topk = request.get("top_k", 1)
-        topp = request.get("top_p", 1.0)
-        max_tokens = request.get("max_tokens", self.meta.dctx)
-        input_content = self.tokenizer.apply_chat_template(
-            conversation=messages,
-            add_generation_prompt=True,
-            tokenize=False,
-        )
-        tokens = self.tokenizer.encode(input_content)
-        ntok = len(tokens)
-        nreq = 1
-        output_content = ""
-        tokens = (c_uint * ntok)(*tokens)
-        req_lens = (c_uint * nreq)(*[ntok])
-        req_pos = (c_uint * nreq)(*[0])
-        kv_caches = (POINTER(KVCache) * nreq)(*[kv_cache])
-        ans = (c_uint * nreq)()
-        steps = 0
-        for step_i in range(max_tokens):
-            infer_batch(
-                self.model_instance,
-                tokens, ntok, req_lens, nreq, req_pos, kv_caches,
-                ans, temperature, topk, topp,
-            )
-            steps += 1
-            output_tokens = list(ans)
-            output_str = (
-                self.tokenizer._tokenizer.id_to_token(output_tokens[0])
-                .replace("▁", " ")
-                .replace("<0x0A>", "\n")
-            )
-            output_content += output_str
-            if output_tokens[0] in self.eos_token_id:
-                break
-            req_pos[0] = req_pos[0] + ntok
-            ntok = 1
-            tokens = (c_uint * ntok)(*output_tokens)
-            req_lens = (c_uint * nreq)(*[ntok])
-        return output_content
-
-    async def chat_stream_async(self, request, kv_cache):
-        messages = request.get("messages", [])
-        temperature = request.get("temperature", 1.0)
-        topk = request.get("top_k", 1)
-        topp = request.get("top_p", 1.0)
-        max_tokens = request.get("max_tokens", 512)
-        input_content = self.tokenizer.apply_chat_template(
-            conversation=messages,
-            add_generation_prompt=True,
-            tokenize=False,
-        )
-        tokens = self.tokenizer.encode(input_content)
-        ntok = len(tokens)
-        nreq = 1
-        tokens = (c_uint * ntok)(*tokens)
-        req_lens = (c_uint * nreq)(*[ntok])
-        req_pos = (c_uint * nreq)(*[0])
-        kv_caches = (POINTER(KVCache) * nreq)(*[kv_cache])
-        ans = (c_uint * nreq)()
-        for step_i in range(max_tokens):
-            infer_batch(
-                self.model_instance,
-                tokens, ntok, req_lens, nreq, req_pos, kv_caches,
-                ans, temperature, topk, topp,
-            )
-            output_tokens = list(ans)
-            output_str = (
-                self.tokenizer._tokenizer.id_to_token(output_tokens[0])
-                .replace("▁", " ")
-                .replace("<0x0A>", "\n")
-            )
-            yield output_str  # Yield each token as it's produced
-            await asyncio.sleep(0)  # Let the event loop breathe
-            if output_tokens[0] in self.eos_token_id:
-                break
-            req_pos[0] += ntok
-            ntok = 1
-            tokens = (c_uint * ntok)(*output_tokens)
-            req_lens = (c_uint * nreq)(*[ntok])
-
-    def generate(self, input_content, max_steps, topp=1.0, topk=1, temperature=1.0):
+    def batch_infer_one_round(self, tasks: List[InferTask]):
+        output = (c_uint * len(tasks))()
+        batch_inputs = JiugeBatchedTask(tasks)
+        infer_batch(
+            self.model_instance,
+            *(batch_inputs.input_args()),
+            output,
+        )
+        return list(output)
+
+    def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0):
         kv_cache = create_kv_cache(self.model_instance)
         input_content = self.tokenizer.apply_chat_template(
             conversation=[{"role": "user", "content": input_content}],
@@ -510,6 +451,9 @@ class JiugeForCauslLM:
         req_pos = (c_uint * nreq)(*[0])
         kv_caches = (POINTER(KVCache) * nreq)(*[kv_cache])
         ans = (c_uint * nreq)()
+        temperature = (c_float * nreq)(*[temperature_])
+        topk = (c_uint * nreq)(*[topk_])
+        topp = (c_float * nreq)(*[topp_])
         steps = 0
         total_time = 0
@@ -524,10 +468,10 @@ class JiugeForCauslLM:
             nreq,
             req_pos,
             kv_caches,
-            ans,
             temperature,
             topk,
             topp,
+            ans,
         )
         steps += 1
         output_tokens = list(ans)
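
batch_infer_one_round performs exactly one decoding step for every task in the batch, which is what lets the server interleave requests. A sketch of driving it in a loop until every task finishes, assuming each task already has a KV cache bound; this mirrors what the worker loop in scripts/launch_server.py does:

from typing import List
from infer_task import InferTask

def run_to_completion(model, tasks: List[InferTask]):
    """Decode all tasks round by round; finished tasks leave the batch."""
    active = list(tasks)
    while active:
        out_tokens = model.batch_infer_one_round(active)
        for task, tok in zip(active, out_tokens):
            task.output(tok)  # updates pos/finish_reason, queues the token
        active = [t for t in active if t.finish_reason is None]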
scripts/kvcache_pool.py

@@ -10,40 +10,70 @@ class KVCachePoolItem:
     def drop(self, model):
         model.drop_kv_cache(self.kvcache)

+    def update_tokens(self, tokens, pos):
+        end = pos + len(tokens)
+        max_len = len(self.tokens)
+        # If overflow, truncate tokens to fit
+        if end > max_len:
+            tokens = tokens[: max_len - pos]
+            end = max_len
+        self.tokens[pos:end] = tokens
+
+import threading

 class KVCachePool:
     def __init__(self, model, max_caches: int = 32):
         self.max_caches = max_caches
         self.model = model
-        self._available: List[KVCachePoolItem] = [KVCachePoolItem(self.model)]
-        self.num_caches = 1
-        self._lock = asyncio.Lock()
-        self._not_empty = asyncio.Condition(self._lock)
+        self._available: List[KVCachePoolItem] = []
+        self.num_caches = len(self._available)
+        self._lock = threading.Lock()
+        self._not_empty = threading.Condition(self._lock)
         self._shutdown = False

-    async def acquire(self, infer_task):
-        async with self._not_empty:
+    def acquire_sync(self, infer_task):
+        with self._not_empty:
             while True:
                 if self._shutdown:
-                    raise RuntimeError(
-                        "KVCachePool is shutting down; cannot acquire new cache."
-                    )
+                    raise RuntimeError("KVCachePool is shutting down; cannot acquire new cache.")
                 if len(self._available) == 0:
                     if self.num_caches < self.max_caches:
                         self.num_caches += 1
                         print(f"[INFO] Task {infer_task.id} created new KVCachePoolItem")
                         return infer_task.bind_kvcache(KVCachePoolItem(self.model), 0)
                     else:
-                        await self._not_empty.wait()
+                        self._not_empty.wait()
                 else:
                     max_match, max_match_index = self.find_most_matching_cache(infer_task.tokens)
                     kvcache = self._available.pop(max_match_index)
                     print(f"[INFO] Task {infer_task.id} reused KVCachePoolItem {max_match_index} with {max_match} matches")
                     return infer_task.bind_kvcache(kvcache, max_match)

-    async def release(self, infer_task):
-        async with self._not_empty:
+    def release_sync(self, infer_task):
+        with self._not_empty:
             print(f"[INFO] Task {infer_task.id} returned KVCachePoolItem to pool")
             self._available.append(infer_task._kv_cache_pool_item)
             infer_task._kv_cache_pool_item = None
             self._not_empty.notify()

+    async def acquire(self, infer_task):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.acquire_sync, infer_task)
+
+    async def release(self, infer_task):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.release_sync, infer_task)

     def find_most_matching_cache(self, tokens: List[int]):
         max_match = 0
         max_match_index = 0
@@ -56,20 +86,20 @@ class KVCachePool:
         for i, kvcache in enumerate(self._available):
             common_elements = first_different_index(tokens, kvcache.tokens)
             # print(f"{tokens}")
             # print(f"{kvcache.tokens[:len(tokens)]}")
             if common_elements > max_match:
                 max_match = common_elements
                 max_match_index = i
         # max match should always be less than the input token length
         return (min(max_match, len(tokens) - 1), max_match_index)

-    async def finalize(self):
-        async with self._not_empty:
+    def finalize(self):
+        with self._not_empty:
             self._shutdown = True
             while len(self._available) < self.num_caches:
-                await self._not_empty.wait()
+                self._not_empty.wait()
             # All caches are now available
             for kvcache in self._available:
                 if kvcache is not None:
                     kvcache.drop(self.model)
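
The pool now hands a new task the cached sequence sharing the longest token prefix with its prompt, capped at len(tokens) - 1 so at least one position is always recomputed to produce a next token. A small illustration; first_different_index is written out here as an assumed common-prefix helper, since its definition sits outside the shown hunks:

def first_different_index(a, b):
    # Assumed helper: length of the common prefix of two token lists.
    n = min(len(a), len(b))
    for i in range(n):
        if a[i] != b[i]:
            return i
    return n

cached = [1, 2, 3, 4, 5, 0, 0]   # tokens already held by one pooled KV cache
prompt = [1, 2, 3, 4, 5]         # new request sharing the same prefix

match = first_different_index(prompt, cached)   # 5
match = min(match, len(prompt) - 1)             # capped to 4
# The task binds at pos=4: only prompt[4:] is recomputed, while the
# first four positions of the reused KV cache are kept as-is.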
scripts/launch_server.py

 import asyncio
 from jiuge import JiugeForCauslLM
 from libinfinicore_infer import DeviceType
 from infer_task import InferTask
 from kvcache_pool import KVCachePool
+import queue
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse, JSONResponse
 import anyio
+import contextlib
 import uvicorn
 import time
 import uuid
 import sys
 import signal
 import json
+import threading
+import janus

 if len(sys.argv) < 3:
     print(
@@ -40,26 +41,6 @@ else:
     sys.exit(1)
 ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1

-MODEL = JiugeForCauslLM(model_path, device_type, ndev)
-App = FastAPI()
-
-@App.on_event("startup")
-async def setup():
-    App.state.kv_cache_pool = KVCachePool(MODEL, 1)
-
-async def handle_shutdown():
-    await App.state.kv_cache_pool.finalize()
-    MODEL.destroy_model_instance()
-    sys.exit(0)
-
-def signal_handler(sig, frame):
-    print(f"Received signal {sig}, cleaning up...")
-    asyncio.create_task(handle_shutdown())
-
-signal.signal(signal.SIGINT, signal_handler)  # Handle Ctrl+C
-signal.signal(signal.SIGTERM, signal_handler)  # Handle docker stop / system shutdown
-
 def chunk_json(id_, content=None, role=None, finish_reason=None):
     delta = {}
@@ -84,50 +65,156 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
     }

+MAX_BATCH = 3
+print(f"Using MAX_BATCH={MAX_BATCH}. Try reducing this value if an out-of-memory error occurs.")
+
+@contextlib.asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Startup
+    app.state.model = JiugeForCauslLM(model_path, device_type, ndev)
+    app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH)
+    app.state.request_queue = janus.Queue()
+    worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True)
+    worker_thread.start()
+    try:
+        yield  # The app runs here
+    finally:
+        # Shutdown
+        app.state.request_queue.sync_q.put(None)
+        worker_thread.join()
+        app.state.request_queue.shutdown()
+        app.state.kv_cache_pool.finalize()
+        app.state.model.destroy_model_instance()
+
+App = FastAPI(lifespan=lifespan)
+
+# App loop: take requests from the queue, do inference, and put unfinished requests back into the queue.
+def worker_loop(app):
+    while True:
+        try:
+            task = app.state.request_queue.sync_q.get(timeout=0.01)
+        except queue.Empty:
+            continue
+        if task is None:
+            return
+        batch = [task]
+        while len(batch) < MAX_BATCH:
+            try:
+                req = app.state.request_queue.sync_q.get_nowait()
+                if req is not None:
+                    batch.append(req)
+            except queue.Empty:
+                break
+        output_tokens = app.state.model.batch_infer_one_round(batch)
+        for task, token in zip(batch, output_tokens):
+            task.output(token)
+            if task.finish_reason is None:
+                app.state.request_queue.sync_q.put(task)
+            else:
+                print(f"[INFO] Task {task.id} finished infer.")
+                app.state.kv_cache_pool.release_sync(task)
+
+def build_task(id_, request_data, request: Request):
+    messages = request_data.get("messages", [])
+    input_content = request.app.state.model.tokenizer.apply_chat_template(
+        conversation=messages,
+        add_generation_prompt=True,
+        tokenize=False,
+    )
+    tokens = request.app.state.model.tokenizer.encode(input_content)
+    return InferTask(
+        id_,
+        tokens,
+        request_data.get("max_tokens", request.app.state.model.max_context_len()),
+        request_data.get("temperature", 1.0),
+        request_data.get("top_k", 1),
+        request_data.get("top_p", 1.0),
+        request.app.state.model.eos_token_id,
+    )

 async def chat_stream(id_, request_data, request: Request):
     try:
-        infer_task = InferTask(id_, MODEL.tokenizer, request_data)
-        await App.state.kv_cache_pool.acquire(infer_task)
+        infer_task = build_task(id_, request_data, request)
+        await request.app.state.kv_cache_pool.acquire(infer_task)

         # Initial empty content
         chunk = json.dumps(
-            chunk_json(id_, content="", role="assistant"),
-            ensure_ascii=False,
+            chunk_json(id_, content="", role="assistant"), ensure_ascii=False
         )
         yield f"{chunk}\n\n"

-        async for token in MODEL.chat_stream_async(
-            infer_task.request,
-            infer_task.kvcache(),
-        ):
-            chunk = json.dumps(
-                chunk_json(id_, content=token),
-                ensure_ascii=False,
-            )
-            yield f"{chunk}\n\n"
-    finally:
-        await App.state.kv_cache_pool.release(infer_task)
-        chunk = json.dumps(
-            chunk_json(id_, finish_reason="stop"),
-            ensure_ascii=False,
-        )
-        yield f"{chunk}\n\n"
+        request.app.state.request_queue.sync_q.put(infer_task)
+        while True:
+            if await request.is_disconnected():
+                print("Client disconnected. Aborting stream.")
+                break
+            if (
+                infer_task.finish_reason is not None
+                and infer_task.output_queue.async_q.empty()
+            ):
+                chunk = json.dumps(
+                    chunk_json(id_, finish_reason=infer_task.finish_reason),
+                    ensure_ascii=False,
+                )
+                yield f"{chunk}\n\n"
+                break
+            token = await infer_task.output_queue.async_q.get()
+            content = (
+                request.app.state.model.tokenizer._tokenizer.id_to_token(token)
+                .replace("▁", " ")
+                .replace("<0x0A>", "\n")
+            )
+            chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False)
+            yield f"{chunk}\n\n"
+    except Exception as e:
+        print(f"[Error] ID: {id_} Exception: {e}")

-async def chat(id_, request_data):
-    infer_task = InferTask(id_, MODEL.tokenizer, request_data)
-    await App.state.kv_cache_pool.acquire(infer_task)
-    output_text = MODEL.chat(
-        infer_task.request,
-        infer_task.kvcache(),
-    )
-    response = chunk_json(
-        id_, content=output_text.strip(), role="assistant", finish_reason="stop"
-    )
-    await App.state.kv_cache_pool.release(infer_task)
-    return JSONResponse(response)
+async def chat(id_, request_data, request: Request):
+    try:
+        infer_task = build_task(id_, request_data, request)
+        await request.app.state.kv_cache_pool.acquire(infer_task)
+        request.app.state.request_queue.sync_q.put(infer_task)
+        output = []
+        while True:
+            if (
+                infer_task.finish_reason is not None
+                and infer_task.output_queue.async_q.empty()
+            ):
+                break
+            token = await infer_task.output_queue.async_q.get()
+            content = (
+                request.app.state.model.tokenizer._tokenizer.id_to_token(token)
+                .replace("▁", " ")
+                .replace("<0x0A>", "\n")
+            )
+            output.append(content)
+        output_text = "".join(output).strip()
+        response = chunk_json(
+            id_,
+            content=output_text,
+            role="assistant",
+            finish_reason=infer_task.finish_reason or "stop",
+        )
+        return response
+    except Exception as e:
+        print(f"[Error] ID: {id_} Exception: {e}")
+        return JSONResponse(content={"error": str(e)}, status_code=500)
@@ -144,7 +231,7 @@ async def chat_completions(request: Request):
             chat_stream(id_, data, request), media_type="text/event-stream"
         )
     else:
-        return chat(id_, data)
+        return JSONResponse(chat(id_, data))

 if __name__ == "__main__":
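
For reference, a minimal client for the streaming endpoint. A sketch assuming the server runs on localhost:8000 and the requests package is installed; each stream chunk is a JSON object followed by a blank line, and since the body of chunk_json is elided above, the OpenAI-style choices[0].delta access below is an assumption:

import json
import requests

payload = {
    "model": "FM9G-7B",
    "messages": [{"role": "user", "content": "给我讲个故事"}],
    "max_tokens": 512,
    "stream": True,
}
with requests.post("http://localhost:8000/chat/completions",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue  # skip the blank separator between chunks
        chunk = json.loads(line)
        # Assumed chunk layout (see chunk_json for the actual schema).
        delta = chunk["choices"][0]["delta"] if "choices" in chunk else {}
        print(delta.get("content", ""), end="", flush=True)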
scripts/libinfinicore_infer.py

@@ -94,6 +94,7 @@ def __open_library__():
         POINTER(c_int),  # int const *dev_ids
     ]
     lib.destroyJiugeModel.argtypes = [POINTER(JiugeModel)]
     lib.createKVCache.argtypes = [POINTER(JiugeModel)]
     lib.createKVCache.restype = POINTER(KVCache)
     lib.dropKVCache.argtypes = [POINTER(JiugeModel), POINTER(KVCache)]
+    lib.inferBatch.restype = None
@@ -105,10 +106,10 @@ def __open_library__():
         c_uint,  # unsigned int nreq
         POINTER(c_uint),  # unsigned int const *req_pos
         POINTER(POINTER(KVCache)),  # struct KVCache **kv_caches
+        POINTER(c_float),  # float temperature
+        POINTER(c_uint),  # unsigned int topk
+        POINTER(c_float),  # float topp
         POINTER(c_uint),  # unsigned int *output
-        c_float,  # float temperature
-        c_uint,  # unsigned int topk
-        c_float,  # float topp
     ]
     return lib
scripts/test_server.py

@@ -5,13 +5,14 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 API_URL = "http://localhost:8000/chat/completions"
 MODEL = "FM9G-7B"
-PROMPT = ["给我讲个故事", "山东最高的山是?"]
+PROMPT = ["山东最高的山是?", "给我讲个故事"]
 CONCURRENCY = 10  # number of concurrent users

 def single_run(user_id):
     payload = {
         "model": MODEL,
         "messages": [{"role": "user", "content": PROMPT[user_id % len(PROMPT)]}],
         "max_tokens": 512,
         "stream": True
     }
     headers = {'Content-Type': 'application/json', 'Accept': 'application/json'}
@@ -87,6 +88,9 @@ def main():
             best_stream = r['stream_time']
             best = r

+    # Sort results by user ID
+    results.sort(key=lambda x: x["user"])
+
     with open("responses.txt", "w", encoding="utf-8") as fw:
         for r in results:
             fw.write(f"[User {r['user']}]\n")
src/models/jiuge/jiuge.cpp

@@ -115,8 +115,8 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
                       const uint32_t *tokens, uint32_t ntok,
                       const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                       struct KVCache **kv_caches,
-                      uint32_t *ans, float temperature, uint32_t topk, float topp) {
+                      const float *temperature, const uint32_t *topk, const float *topp,
+                      uint32_t *output) {
     auto nlayer = meta.nlayer;
     auto nkvh = meta.nkvh / ndev;
     auto nh = meta.nh / ndev;
@@ -457,8 +457,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
         RUN_INFINI(infiniopRandomSample(desc_sample, workspace, workspace_size,
                                         result_buf->data(req),
                                         prob_buf->data(req * dvoc),
                                         random_val,
-                                        topp, topk, temperature,
+                                        topp[req], topk[req], temperature[req],
                                         stream));
         // result_buf->debug();
         token_offset += seq_len;
     }
@@ -466,7 +468,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
     RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(),
                               sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H));
     for (uint32_t req = 0; req < nreq; req++) {
-        ans[req] = result_cpu[req];
+        output[req] = result_cpu[req];
     }
 }
@@ -500,15 +502,15 @@ inferBatch(struct JiugeModel *model,
            const uint32_t *tokens, uint32_t ntok,
            const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
            struct KVCache **kv_caches,
-           uint32_t *ans, float temperature, uint32_t topk, float topp) {
+           const float *temperature, const uint32_t *topk, const float *topp,
+           uint32_t *output) {
     model->req.tokens = tokens;
     model->req.ntok = ntok;
    model->req.req_lens = req_lens;
     model->req.nreq = nreq;
     model->req.req_pos = req_pos;
     model->req.kv_caches = kv_caches;
-    model->req.ans = ans;
+    model->req.output = output;
     model->req.temperature = temperature;
     model->req.topk = topk;
     model->req.topp = topp;
@@ -547,7 +549,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource
             break;
         }
-        inferDeviceBatch(meta, *rsrc, idev, ndev,
-                         req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos,
-                         req.kv_caches, req.ans,
-                         req.temperature, req.topk, req.topp);
+        inferDeviceBatch(meta, *rsrc, idev, ndev,
+                         req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos,
+                         req.kv_caches,
+                         req.temperature, req.topk, req.topp, req.output);
         state.proceed = false;
         lock.unlock();
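
With the scalar parameters replaced by arrays, the sampler is configured per request inside the batch loop (topp[req], topk[req], temperature[req]), so greedy and stochastic requests can share one batch. A rough Python model of those semantics, not the infiniop kernel, assuming a plain softmax over raw logits:

import math, random

def sample_next(logits, temperature, topk, topp):
    # Greedy when temperature is 0 or topk is 1 (matches the header docs).
    if temperature == 0.0 or topk == 1:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    probs = [math.exp(s - m) for s in scaled]
    total = sum(probs)
    probs = [p / total for p in probs]
    # Keep the top-k tokens, then the smallest set whose mass reaches top-p.
    order = sorted(range(len(probs)), key=lambda i: -probs[i])[:topk]
    kept, acc = [], 0.0
    for i in order:
        kept.append(i)
        acc += probs[i]
        if acc >= topp:
            break
    r = random.uniform(0, acc)
    for i in kept:
        r -= probs[i]
        if r <= 0:
            return i
    return kept[-1]

# One batch, two requests with different sampling settings:
batch_logits = [[2.0, 1.0, 0.5], [2.0, 1.0, 0.5]]
temperature, topk, topp = [0.0, 0.8], [1, 3], [1.0, 0.9]
next_tokens = [sample_next(batch_logits[req], temperature[req], topk[req], topp[req])
               for req in range(2)]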
src/models/jiuge/jiuge_impl.hpp

@@ -45,10 +45,10 @@ struct InferRequest {
     uint32_t nreq;
     const uint32_t *req_pos;
     struct KVCache **kv_caches;
-    uint32_t *ans;
-    float temperature;
-    uint32_t topk;
-    float topp;
+    const float *temperature;
+    const uint32_t *topk;
+    const float *topp;
+    uint32_t *output;
 };

 struct JiugeModel {