jerrrrry / infinilm · Commits

Commit 96ecf490 (Unverified)
Authored Feb 06, 2026 by Haojie Wang; committed via GitHub on Feb 06, 2026.

    Merge pull request #203 from InfiniTensor/issue/193

    Issue/193: adapt inference_server to deployment requirements

Parents: f2d9d397, 0cddb99e
Showing 7 changed files with 242 additions and 37 deletions (+242, −37):

    python/infinilm/llm/cache_manager.py          +6    −0
    python/infinilm/llm/llm.py                    +86   −5
    python/infinilm/llm/request.py                +4    −0
    python/infinilm/llm/sampling_params.py        +1    −1
    python/infinilm/llm/scheduler.py              +40   −2
    python/infinilm/server/inference_server.py    +104  −28
    test/bench/test_benchmark.py                  +1    −1
python/infinilm/llm/cache_manager.py (+6 −0)

```diff
@@ -261,6 +261,12 @@ class BlockManager:
     def get_num_free_blocks(self) -> int:
         return len(self.free_block_ids)
 
+    def get_total_usable_blocks(self) -> int:
+        freeable_used_blocks = sum(
+            1 for bid in self.used_block_ids if self.blocks[bid].ref_count == 0
+        )
+        return len(self.free_block_ids) + freeable_used_blocks
+
     def __repr__(self):
         return (
             f"BlockManager(blocks={self.num_blocks}, block_size={self.block_size}, "
```
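The distinction the new method draws is between blocks that are free right now and blocks that are merely "usable" (free plus used-but-unreferenced blocks that could be evicted, e.g. idle prefix-cache blocks). A minimal standalone sketch of that accounting, with a toy stand-in for the real BlockManager:

```python
from dataclasses import dataclass


@dataclass
class Block:
    ref_count: int = 0


class TinyBlockManager:
    """Toy stand-in for BlockManager, only to illustrate the accounting above."""

    def __init__(self, num_blocks: int):
        self.blocks = [Block() for _ in range(num_blocks)]
        self.free_block_ids = set(range(num_blocks))
        self.used_block_ids = set()

    def get_num_free_blocks(self) -> int:
        return len(self.free_block_ids)

    def get_total_usable_blocks(self) -> int:
        # "Usable" = truly free blocks plus used blocks that nothing references anymore.
        freeable = sum(1 for bid in self.used_block_ids if self.blocks[bid].ref_count == 0)
        return len(self.free_block_ids) + freeable


mgr = TinyBlockManager(8)
# Mark blocks 0-3 as used; only blocks 0 and 1 are still referenced by a request.
for bid in range(4):
    mgr.free_block_ids.discard(bid)
    mgr.used_block_ids.add(bid)
mgr.blocks[0].ref_count = mgr.blocks[1].ref_count = 1

print(mgr.get_num_free_blocks())      # 4
print(mgr.get_total_usable_blocks())  # 6 = 4 free + 2 freeable
```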
python/infinilm/llm/llm.py (+86 −5)

```diff
@@ -228,12 +228,63 @@ class LLMEngine:
             req.generated_token_ids.append(token_id)
 
             if req.is_prefill:
                 req.is_prefill = False
 
-            token_text = self.tokenizer.decode(token_id)
-            req.generated_text += token_text
-            if self._check_request_finished(req, token_id):
+            # vLLM-style replacement character handling is primarily relevant for streaming.
+            # For offline generation (no output queue), keep the fast incremental path.
+            if req._output_queue is None:
+                token_text = self.detokenize([token_id])
+                req.generated_text += token_text
+            else:
+                # Streaming path: compute delta from a full decode so we can hold back
+                # trailing '\ufffd' (likely an incomplete UTF-8 sequence).
+                decoded_text = self.detokenize(req.generated_token_ids)
+                finished_now = False
+                # Update generated_text to the latest decode (used for stop-string checks and debugging)
+                req.generated_text = decoded_text
+                if self._check_request_finished(req, token_id):
+                    req.mark_finished(req.finish_reason)
+                    finished_now = True
+                    # Remove stop string from generated_text if STOP_STRING finish reason
+                    if req.finish_reason == FinishReason.STOP_STRING:
+                        stop_strings = req.sampling_params.stop or []
+                        for stop_str in stop_strings:
+                            if decoded_text.endswith(stop_str):
+                                # Remove the stop string from the end
+                                decoded_text = decoded_text[:-len(stop_str)]
+                                req.generated_text = decoded_text
+                                break
+
+                holds_back_incomplete_utf8 = (
+                    bool(decoded_text) and decoded_text.endswith("\ufffd")
+                )
+                # vLLM-style: hold back only if we are not on the final chunk.
+                # Suppress output when finish reason is LENGTH or STOP_STRING.
+                # Root cause fix: When STOP_STRING is detected, we suppress output for the token
+                # that completes the stop string, preventing additional tokens from being output.
+                if (holds_back_incomplete_utf8 and not finished_now) or (
+                    finished_now
+                    and req.finish_reason in (FinishReason.LENGTH, FinishReason.STOP_STRING)
+                ):
+                    token_text = ""
+                else:
+                    last_len = getattr(req, "_stream_last_yielded_length", 0)
+                    token_text = decoded_text[last_len:]
+                    if token_text:
+                        req._stream_last_yielded_length = len(decoded_text)
+
+            # For non-streaming, finish checks happen here.
+            if req._output_queue is None and self._check_request_finished(req, token_id):
                 req.mark_finished(req.finish_reason)
                 # Remove stop string from generated_text if STOP_STRING finish reason
                 if req.finish_reason == FinishReason.STOP_STRING:
                     stop_strings = req.sampling_params.stop or []
                     for stop_str in stop_strings:
                         if req.generated_text.endswith(stop_str):
                             # Remove the stop string from the end
                             req.generated_text = req.generated_text[:-len(stop_str)]
                             break
 
             # Put output in queue if it exists (for async streaming)
             if req._output_queue is not None:
```
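The core of the streaming change is the delta-plus-hold-back rule: re-decode the full token sequence, emit only the suffix past the last yielded length, and withhold the chunk while the text still ends in U+FFFD (an incomplete multi-byte sequence). A self-contained sketch of that rule, using raw byte chunks as a hypothetical "tokenizer" instead of the engine's real one:

```python
# Standalone illustration only; the engine's detokenize() plays the role of decode() here.
token_bytes = ["你".encode("utf-8")[:2], "你".encode("utf-8")[2:], "好".encode("utf-8")]

emitted = []
last_yielded = 0
buf = b""
for tb in token_bytes:
    buf += tb
    decoded = buf.decode("utf-8", errors="replace")
    if decoded and decoded.endswith("\ufffd"):
        # Likely an incomplete multi-byte sequence: hold the chunk back for now.
        continue
    delta = decoded[last_yielded:]
    if delta:
        emitted.append(delta)
        last_yielded = len(decoded)

print(emitted)  # ['你', '好'] -- no stray replacement characters leak into the stream
```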
```diff
@@ -283,12 +334,15 @@ class LLMEngine:
         self,
         messages: List[dict],
         add_generation_prompt: bool = True,
+        chat_template_kwargs: Optional[dict] = None,
     ) -> str:
         """Apply chat template to messages."""
+        chat_template_kwargs = chat_template_kwargs or {}
         return self.tokenizer.apply_chat_template(
             conversation=messages,
             add_generation_prompt=add_generation_prompt,
             tokenize=False,
+            **chat_template_kwargs,
         )
@@ -486,6 +540,10 @@ class AsyncLLMEngine:
         self._running = False
         self._step_thread: Optional[threading.Thread] = None
+        self._healthy = True
+
+    def is_healthy(self) -> bool:
+        return bool(self._healthy)
 
     def start(self):
         """Start the background inference loop."""
@@ -520,6 +578,7 @@ class AsyncLLMEngine:
                 time.sleep(0.01)
             except Exception as e:
                 logger.error(f"Error in step loop: {e}", exc_info=True)
+                self._healthy = False
                 self._running = False
                 break
@@ -581,6 +640,8 @@ class AsyncLLMEngine:
         request_id: Optional[str] = None,
         request_data: Optional[dict] = None,
         http_request: Optional[any] = None,
+        add_generation_prompt: bool = True,
+        chat_template_kwargs: Optional[dict] = None,
     ) -> InferenceRequest:
         """Add a chat request to the engine.
@@ -594,7 +655,11 @@ class AsyncLLMEngine:
         Returns:
             The created InferenceRequest object.
         """
-        prompt = self.engine.apply_chat_template(messages, add_generation_prompt=True)
+        prompt = self.engine.apply_chat_template(
+            messages,
+            add_generation_prompt=add_generation_prompt,
+            chat_template_kwargs=chat_template_kwargs,
+        )
         return self.add_request(
             prompt=prompt,
             sampling_params=sampling_params,
@@ -607,6 +672,7 @@ class AsyncLLMEngine:
         self,
         request: InferenceRequest,
         timeout: float = 100.0,
+        request_timeout: Optional[float] = None,
     ) -> AsyncIterator[TokenOutput]:
         """Stream tokens from a request.
@@ -619,6 +685,7 @@ class AsyncLLMEngine:
         """
         import asyncio
 
+        start = time.time()
         while True:
             if request.is_finished() and request.output_queue.async_q.empty():
                 break
@@ -635,6 +702,20 @@ class AsyncLLMEngine:
                 if token_output.finished:
                     break
             except asyncio.TimeoutError:
+                # Enforce request-level timeout even if no tokens are produced.
+                if request_timeout is not None:
+                    now = time.time()
+                    if now - start > float(request_timeout):
+                        request.mark_timeout()
+                        yield TokenOutput(
+                            request_id=request.request_id,
+                            token_id=-1,
+                            token_text="",
+                            finished=True,
+                            finish_reason=FinishReason.TIMEOUT,
+                            generated_text=request.generated_text,
+                        )
+                        break
                 if request.is_finished():
                     break
                 continue
```
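stream_request now layers two timeouts: a short per-get timeout on the output queue and an overall request_timeout deadline checked whenever a get times out. A minimal standalone sketch of that pattern, with a plain asyncio.Queue standing in for the request's janus output queue (names and the "<eos>" sentinel are illustrative, not the engine's API):

```python
import asyncio
import time


async def consume(queue: asyncio.Queue, get_timeout: float, request_timeout: float):
    """Yield (text, finish_reason) pairs, enforcing a per-get and an overall deadline."""
    start = time.time()
    while True:
        try:
            item = await asyncio.wait_for(queue.get(), timeout=get_timeout)
        except asyncio.TimeoutError:
            # No token arrived within get_timeout; also enforce the overall deadline.
            if time.time() - start > request_timeout:
                yield ("", "timeout")
                return
            continue
        yield (item, None)
        if item == "<eos>":
            return


async def main():
    q = asyncio.Queue()
    for tok in ["Hello", " world", "<eos>"]:
        q.put_nowait(tok)
    async for text, reason in consume(q, get_timeout=0.1, request_timeout=5.0):
        print(repr(text), reason)


asyncio.run(main())
```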
python/infinilm/llm/request.py (+4 −0)

```diff
@@ -144,6 +144,10 @@ class InferenceRequest:
         # Output management (for async streaming)
         self._output_queue: Optional[janus.Queue] = None
 
+        # Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer)
+        # Used by the engine to compute "delta" text chunks from a full decode.
+        self._stream_last_yielded_length: int = 0
+
     @property
     def output_queue(self) -> janus.Queue:
         """Lazy initialization of output queue."""
```
python/infinilm/llm/sampling_params.py (+1 −1)

```diff
@@ -15,7 +15,7 @@ class SamplingParams:
     top_k: int = 1
     max_tokens: Optional[int] = None
     stop: Optional[List[str]] = None
-    stop_token_ids: Optional[List[int]] = None
+    stop_token_ids: Optional[List[int]] = None  # Placeholder for future usage, not currently handled
 
     def __post_init__(self):
         if self.stop is None:
```
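For context, a hedged example of constructing SamplingParams with the fields shown above (the import path is inferred from the file location and may differ; values are illustrative):

```python
from infinilm.llm.sampling_params import SamplingParams  # import path assumed from the diff

params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    top_k=1,
    max_tokens=256,
    stop=["</s>", "\nUser:"],   # stop strings are stripped from the final text by the engine
    stop_token_ids=None,        # placeholder per the new comment; not handled yet
)
```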
python/infinilm/llm/scheduler.py (+40 −2)

```diff
@@ -155,12 +155,21 @@ class Scheduler:
             except queue.Empty:
                 break
 
+            if not self.can_accept_request(req):
+                self.waiting_queue.sync_q.put(req)
+                break
+
+            # Skip requests that were already finished (e.g., timed out/canceled while waiting)
+            if req.is_finished():
+                self.complete_requests([req])
+                continue
+
             req_tokens = req.get_input_tokens()
             num_required_blocks = req.get_num_blocks_required(self.block_size)
             if not self.cache_manager.can_allocate(num_required_blocks):
                 if not self.cache_manager.try_free_blocks(num_required_blocks):
-                    raise RuntimeError("No available cache blocks")
+                    raise RuntimeError("No available cache blocks for new request")
 
             # Allocate blocks with automatic prefix caching support
             req.block_table, req.slot_mapping, req.num_cached_tokens = (
@@ -185,6 +194,10 @@ class Scheduler:
                 req = self.running_queue.sync_q.get_nowait()
             except queue.Empty:
                 break
+            # Skip requests that were already finished (e.g., timed out/canceled while running)
+            if req.is_finished():
+                self.complete_requests([req])
+                continue
 
             # Decode phase: allocate slot for newly generated token
             try:
@@ -197,7 +210,7 @@ class Scheduler:
                 scheduled_requests.append(req)
             except RuntimeError as e:
-                raise RuntimeError("No available cache blocks for new token") from e
+                raise RuntimeError("No available cache blocks for new token") from e
 
         # Return decode batch if any running requests were scheduled
         if scheduled_requests:
@@ -237,6 +250,31 @@ class Scheduler:
                 # Still running, put back in running queue
                 self.running_queue.sync_q.put(req)
 
+    def can_accept_request(self, request: InferenceRequest) -> bool:
+        total_required_blocks = 0
+
+        # Calculate blocks needed for running requests
+        running_queue_size = self.running_queue.sync_q.qsize()
+        for _ in range(running_queue_size):
+            req = self.running_queue.sync_q.get()
+            remaining_tokens = (
+                req.sampling_params.max_tokens - req.get_num_generated_tokens()
+            )
+            num_blocks_needed = (remaining_tokens + self.block_size - 1) // self.block_size
+            total_required_blocks += num_blocks_needed
+            self.running_queue.sync_q.put(req)
+
+        # Calculate blocks needed for the new request
+        total_length = request.get_prompt_length()
+        total_length += request.sampling_params.max_tokens
+        num_blocks_needed = (total_length + self.block_size - 1) // self.block_size
+        total_required_blocks += num_blocks_needed
+
+        # Compare with total usable blocks in cache manager
+        return total_required_blocks <= self.cache_manager.get_total_usable_blocks()
+
     def get_cache_stats(self) -> dict:
         """Get cache statistics."""
         return {
```
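The admission check in can_accept_request is just ceil-division over block_size: reserve blocks for every running request's remaining output plus the new request's prompt and worst-case output, and accept only if that fits into the usable block budget. A worked example (numbers made up):

```python
# Worked example of the ceil-division block math used by can_accept_request.
block_size = 16


def blocks_needed(num_tokens: int) -> int:
    return (num_tokens + block_size - 1) // block_size


# A running request that may still generate 50 more tokens:
print(blocks_needed(50))        # 4 blocks reserved for its remaining output

# A new request with a 100-token prompt and max_tokens=56:
print(blocks_needed(100 + 56))  # 10 blocks for prompt plus worst-case output

# Admission: accept only if 4 + 10 <= cache_manager.get_total_usable_blocks().
```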
python/infinilm/server/inference_server.py (+104 −28)

```diff
@@ -10,6 +10,7 @@ import uuid
 import argparse
 import uvicorn
 import logging
+import os
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -22,7 +23,7 @@ DEFAULT_STREAM_TIMEOUT = 100.0
 DEFAULT_REQUEST_TIMEOUT = 1000.0
 
 
-def chunk_json(id_, content=None, role=None, finish_reason=None):
+def chunk_json(id_, content=None, role=None, finish_reason=None, model: str = "unknown"):
     """Generate JSON chunk for streaming response."""
     delta = {}
     if content:
@@ -33,7 +34,7 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
         "id": id_,
         "object": "chat.completion.chunk",
         "created": int(time.time()),
-        "model": "jiuge",
+        "model": model,
         "system_fingerprint": None,
         "choices": [
             {
```
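For orientation, this is roughly what one streaming chunk looks like after the change: the top-level fields come from chunk_json above, while the contents of choices[0] are not shown in this hunk and are assumed to follow the usual OpenAI delta layout.

```python
# Hypothetical chunk as a Python dict (values illustrative; choices[0] layout assumed).
chunk = {
    "id": "chatcmpl-123",
    "object": "chat.completion.chunk",
    "created": 1738800000,
    "model": "my-model-dir",  # now the served model id, not the hard-coded "jiuge"
    "system_fingerprint": None,
    "choices": [
        {"index": 0, "delta": {"content": "Hello"}, "finish_reason": None},  # assumed shape
    ],
}
```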
```diff
@@ -84,6 +85,8 @@ class InferenceServer:
             port: Server port number.
         """
         self.model_path = model_path
+        # vLLM-like served model id: directory name of model_path
+        self.model_id = os.path.basename(os.path.normpath(model_path)) or "model"
         self.device = device
         self.dtype = dtype
         self.tensor_parallel_size = tensor_parallel_size
@@ -136,7 +139,10 @@ class InferenceServer:
     def _register_routes(self, app: FastAPI):
         """Register API routes."""
 
+        # OpenAI-compatible chat completions endpoint.
+        # Support both legacy path and OpenAI-style /v1 prefix for proxy/router compatibility.
         @app.post("/chat/completions")
+        @app.post("/v1/chat/completions")
         async def chat_completions(request: Request):
             try:
                 data = await request.json()
@@ -169,15 +175,21 @@ class InferenceServer:
         @app.get("/health")
         async def health():
+            # Expose engine health so babysitter/registry can treat backend as unhealthy.
+            if (
+                self.engine is not None
+                and hasattr(self.engine, "is_healthy")
+                and not self.engine.is_healthy()
+            ):
+                return JSONResponse(content={"status": "unhealthy"}, status_code=503)
             return {"status": "healthy"}
 
-        @app.get("/v1/models")
-        async def list_models():
+        def _models_payload():
             return {
                 "object": "list",
                 "data": [
                     {
-                        "id": "jiuge",
+                        "id": self.model_id,
                         "object": "model",
                         "created": int(time.time()),
                         "owned_by": "infinilm",
```
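A quick way to exercise the new health semantics and the model listing from a client; a sketch only, the host and port are assumptions and it needs the `requests` package:

```python
import requests

BASE = "http://127.0.0.1:8000"  # assumed host/port

# /health returns 200 {"status": "healthy"} normally and 503 {"status": "unhealthy"}
# once the engine's step loop has died.
r = requests.get(f"{BASE}/health")
print(r.status_code, r.json())

# Both the OpenAI-style and legacy listing routes serve the same payload,
# with "id" taken from the model directory name.
for path in ("/v1/models", "/models"):
    print(path, requests.get(f"{BASE}{path}").json()["data"][0]["id"])
```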
```diff
@@ -185,20 +197,54 @@ class InferenceServer:
                 ],
             }
 
+        # Support both /v1/models (OpenAI) and /models (common legacy) for compatibility.
+        @app.get("/v1/models")
+        async def list_models():
+            return _models_payload()
+
+        @app.get("/models")
+        async def list_models_legacy():
+            return _models_payload()
+
     def _build_sampling_params(self, data: dict) -> SamplingParams:
         """Build SamplingParams from request data."""
+        # Support both:
+        # - top-level OpenAI-ish fields: temperature/top_p/top_k/max_tokens/stop
+        # - nested dict: sampling_params: { ... }
+        sp = data.get("sampling_params") or {}
+        if not isinstance(sp, dict):
+            sp = {}
+
+        def pick(key: str, default):
+            # Priority: explicit top-level field > nested sampling_params > server default
+            if key in data and data.get(key) is not None:
+                return data.get(key)
+            if key in sp and sp.get(key) is not None:
+                return sp.get(key)
+            return default
+
+        # Accept common alias
+        max_tokens = pick("max_tokens", self.max_tokens)
+        if max_tokens is None:
+            # Some clients use max_new_tokens
+            max_tokens = pick("max_new_tokens", self.max_tokens)
+
+        stop = pick("stop", None)
+        if isinstance(stop, str):
+            stop = [stop]
+
         return SamplingParams(
-            temperature=data.get("temperature", self.temperature),
-            top_p=data.get("top_p", self.top_p),
-            top_k=data.get("top_k", self.top_k),
-            max_tokens=data.get("max_tokens", self.max_tokens),
-            stop=data.get("stop"),
+            temperature=float(pick("temperature", self.temperature)),
+            top_p=float(pick("top_p", self.top_p)),
+            top_k=int(pick("top_k", self.top_k)),
+            max_tokens=int(max_tokens) if max_tokens is not None else None,
+            stop=stop,
         )
 
     async def _stream_chat(self, request_id: str, data: dict, http_request: Request):
         """Handle streaming chat request."""
         req = None
-        start_time = time.time()
         try:
             messages = data.get("messages", [])
```
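_build_sampling_params now accepts top-level OpenAI-style fields as well as a nested sampling_params dict, with top-level values taking precedence and max_new_tokens honored as an alias when max_tokens does not resolve to a value. A hedged example request body (field values are illustrative, and the "stream" flag is an assumption not shown in this hunk):

```python
# Illustrative request body; names follow the handler above, values are made up.
payload = {
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 0.6,                 # top-level wins over the nested value below
    "sampling_params": {
        "temperature": 1.0,             # ignored: a top-level temperature is present
        "top_p": 0.9,                   # used: no top-level top_p given
    },
    "max_new_tokens": 128,              # alias, used only when max_tokens is otherwise unset
    "stop": "</s>",                     # a bare string is normalized to ["</s>"]
    "stream": True,                     # assumed OpenAI-style flag (not part of this hunk)
    "add_generation_prompt": True,
    "chat_template_kwargs": {"enable_thinking": False},  # hypothetical example kwarg
}
```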
```diff
@@ -210,22 +256,26 @@ class InferenceServer:
                 request_id=request_id,
                 request_data=data,
                 http_request=http_request,
+                add_generation_prompt=bool(data.get("add_generation_prompt", True)),
+                chat_template_kwargs=data.get("chat_template_kwargs") or {},
             )
 
             async for token_output in self.engine.stream_request(
-                req, timeout=DEFAULT_STREAM_TIMEOUT
+                req,
+                timeout=DEFAULT_STREAM_TIMEOUT,
+                request_timeout=DEFAULT_REQUEST_TIMEOUT,
             ):
-                # Check timeout
-                if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
+                # If stream_request enforces timeout, we can just surface the state to the client.
+                if token_output.finish_reason == FinishReason.TIMEOUT:
                     logger.warning(
                         f"Request {request_id} timed out after {DEFAULT_REQUEST_TIMEOUT}s"
                     )
-                    req.mark_timeout()
                     error_chunk = json.dumps(
                         chunk_json(
                             request_id,
                             content="[Request timeout]",
                             finish_reason="timeout",
+                            model=self.model_id,
                         ),
                         ensure_ascii=False,
                     )
@@ -238,19 +288,31 @@ class InferenceServer:
                     req.mark_canceled()
                     break
 
-                # Send token
-                chunk = json.dumps(
-                    chunk_json(request_id, content=token_output.token_text),
-                    ensure_ascii=False,
-                )
-                yield f"data: {chunk}\n\n"
+                # Skip EOS token text for OpenAI API compatibility
+                # Check if this token is an EOS token by comparing token_id with eos_token_ids
+                eos_token_ids = self.engine.engine.eos_token_ids
+                is_eos_token = (
+                    eos_token_ids and token_output.token_id in eos_token_ids
+                )
+                if not is_eos_token and token_output.token_text:
+                    # Send token
+                    chunk = json.dumps(
+                        chunk_json(
+                            request_id,
+                            content=token_output.token_text,
+                            model=self.model_id,
+                        ),
+                        ensure_ascii=False,
+                    )
+                    yield f"data: {chunk}\n\n"
 
                 if token_output.finished:
                     finish_reason = self._convert_finish_reason(
                         token_output.finish_reason
                     )
                     chunk = json.dumps(
-                        chunk_json(request_id, finish_reason=finish_reason),
+                        chunk_json(request_id, finish_reason=finish_reason, model=self.model_id),
                         ensure_ascii=False,
                     )
                     yield f"data: {chunk}\n\n"
@@ -262,7 +324,10 @@ class InferenceServer:
             req.mark_failed()
             error_chunk = json.dumps(
                 chunk_json(
-                    request_id, content=f"[Error: {str(e)}]", finish_reason="error"
+                    request_id,
+                    content=f"[Error: {str(e)}]",
+                    finish_reason="error",
+                    model=self.model_id,
                 ),
                 ensure_ascii=False,
             )
@@ -278,7 +343,6 @@ class InferenceServer:
     async def _chat(self, request_id: str, data: dict, http_request: Request):
         """Handle non-streaming chat request."""
         req = None
-        start_time = time.time()
         try:
             messages = data.get("messages", [])
@@ -290,17 +354,20 @@ class InferenceServer:
                 request_id=request_id,
                 request_data=data,
                 http_request=http_request,
+                add_generation_prompt=bool(data.get("add_generation_prompt", True)),
+                chat_template_kwargs=data.get("chat_template_kwargs") or {},
             )
 
             # Collect all generated tokens
             output_text = ""
             async for token_output in self.engine.stream_request(
-                req, timeout=DEFAULT_STREAM_TIMEOUT
+                req,
+                timeout=DEFAULT_STREAM_TIMEOUT,
+                request_timeout=DEFAULT_REQUEST_TIMEOUT,
            ):
-                # Check timeout
-                if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
+                # Request-level timeout is handled inside stream_request.
+                if token_output.finish_reason == FinishReason.TIMEOUT:
                     logger.warning(f"Request {request_id} timed out")
-                    req.mark_timeout()
                     break
 
                 # Check client disconnect
@@ -309,7 +376,15 @@ class InferenceServer:
                     req.mark_canceled()
                     break
 
-                output_text += token_output.token_text
+                # Skip EOS token text for OpenAI API compatibility
+                # Check if this token is an EOS token by comparing token_id with eos_token_ids
+                eos_token_ids = self.engine.engine.eos_token_ids
+                is_eos_token = (
+                    eos_token_ids and token_output.token_id in eos_token_ids
+                )
+                if not is_eos_token:
+                    output_text += token_output.token_text
+
                 if token_output.finished:
                     break
@@ -322,6 +397,7 @@ class InferenceServer:
                 content=output_text,
                 role="assistant",
                 finish_reason=finish_reason or "stop",
+                model=self.model_id,
             )
             return response
```
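End to end, a minimal non-streaming call against the adapted server might look like this; host, port, and the "stream" flag are assumptions, and it needs the `requests` package:

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",  # new /v1 route; /chat/completions still works
    json={
        "messages": [{"role": "user", "content": "Say hi in one word."}],
        "max_tokens": 16,
        "stream": False,  # assumed OpenAI-style flag for the non-streaming path
    },
    timeout=60,
)
data = resp.json()
print(data.get("model"))  # the response builder is passed model=self.model_id
print(data)               # exact layout comes from the unchanged response-building code
```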
test/bench/test_benchmark.py (+1 −1)

```diff
@@ -4,7 +4,6 @@ import argparse
 import time
 import re
 import csv
-from datasets import load_dataset, Dataset
 import numpy as np
 import infinicore
 from infinilm.modeling_utils import load_model_state_dict_by_file
@@ -12,6 +11,7 @@ from infinilm.distributed import DistConfig
 from infinilm.cache import StaticKVCacheConfig
 from infinilm.infer_engine import GenerationConfig, InferEngine
 from infinilm.cache import StaticKVCacheConfig
+from datasets import load_dataset, Dataset
 from abc import ABC, abstractmethod
```