jerrrrry / infinilm · Commits · 69f18760

Commit 69f18760, authored Jan 30, 2026 by wooway777

    issue/204 - support graph in server scripts

Parent: 693d74d3
Showing 2 changed files with 48 additions and 22 deletions:

    python/infinilm/llm/llm.py                  +27 −13
    python/infinilm/server/inference_server.py  +21 −9
python/infinilm/llm/llm.py
@@ -50,6 +50,7 @@ class EngineConfig:
         temperature: Default sampling temperature.
         top_p: Default top-p sampling parameter.
         top_k: Default top-k sampling parameter.
+        enable_graph: Whether to enable graph compiling.
     """
     model_path: str
@@ -63,6 +64,7 @@ class EngineConfig:
     temperature: float = 1.0
     top_p: float = 0.8
     top_k: int = 1
+    enable_graph: bool = False


 class LLMEngine:
@@ -74,11 +76,18 @@ class LLMEngine:
         # Initialize device and dtype
         self._init_device()

+        # Initialize KV cache
+        cache_config = PagedKVCacheConfig(
+            num_blocks=config.num_blocks, block_size=config.block_size
+        )
+
         # Initialize model engine
         self.model_engine = InferEngine(
             model_path=config.model_path,
             device=self.device,
             distributed_config=DistConfig(config.tensor_parallel_size),
+            cache_config=cache_config,
+            enable_graph_compiling=config.enable_graph,
         )

         # Load model weights
@@ -92,12 +101,6 @@ class LLMEngine:
         )
         self._fix_tokenizer_decoder()

-        # Initialize KV cache
-        cache_config = PagedKVCacheConfig(
-            num_blocks=config.num_blocks, block_size=config.block_size
-        )
-        self.model_engine.reset_cache(cache_config)
-
         # Initialize scheduler
         self.scheduler = Scheduler(
             max_batch_size=config.max_batch_size,
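Taken together, the two hunks above move KV-cache setup from a post-hoc `reset_cache(...)` call into the `InferEngine` constructor, alongside the new graph flag; a plausible motivation, not stated in the commit, is that graph compiling needs the cache layout fixed before capture. A minimal sketch of the resulting initialization order, using only names that appear in this diff:

```python
# Sketch of the post-commit initialization order in LLMEngine.__init__
# (abridged; `config`, `self.device`, and the imports come from llm.py).
cache_config = PagedKVCacheConfig(
    num_blocks=config.num_blocks, block_size=config.block_size
)
self.model_engine = InferEngine(
    model_path=config.model_path,
    device=self.device,
    distributed_config=DistConfig(config.tensor_parallel_size),
    cache_config=cache_config,                   # previously applied later via reset_cache()
    enable_graph_compiling=config.enable_graph,  # new: opt-in graph compiling
)
```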
@@ -113,6 +116,7 @@ class LLMEngine:
         logger.info(
             f"LLMEngine initialized with model at {config.model_path} "
             f"on device {config.device} "
+            f"enable_graph={config.enable_graph}"
         )

     def _init_device(self):
@@ -252,20 +256,22 @@ class LLMEngine:
             for stop_str in stop_strings:
                 if decoded_text.endswith(stop_str):
                     # Remove the stop string from the end
-                    decoded_text = decoded_text[:-len(stop_str)]
+                    decoded_text = decoded_text[: -len(stop_str)]
                     req.generated_text = decoded_text
                     break

-            holds_back_incomplete_utf8 = (
-                bool(decoded_text) and decoded_text.endswith("\ufffd")
-            )
+            holds_back_incomplete_utf8 = bool(decoded_text) and decoded_text.endswith(
+                "\ufffd"
+            )

             # vLLM-style: hold back only if we are not on the final chunk.
             # Suppress output when finish reason is LENGTH or STOP_STRING.
             # Root cause fix: When STOP_STRING is detected, we suppress output for the token
             # that completes the stop string, preventing additional tokens from being output.
             if (holds_back_incomplete_utf8 and not finished_now) or (
-                finished_now and req.finish_reason in (FinishReason.LENGTH, FinishReason.STOP_STRING)
+                finished_now
+                and req.finish_reason
+                in (FinishReason.LENGTH, FinishReason.STOP_STRING)
             ):
                 token_text = ""
             else:
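For reference, the hold-back rule that this hunk reformats can be restated as a small standalone function. This is an illustrative sketch, not repository code; the names `decoded_text`, `finished_now`, `finish_reason`, and `FinishReason` are taken from the diff, while the function wrapper and the stand-in enum are hypothetical:

```python
from enum import Enum, auto


class FinishReason(Enum):  # hypothetical stand-in for the project's enum
    LENGTH = auto()
    STOP_STRING = auto()
    EOS = auto()


def visible_chunk(decoded_text, token_text, finished_now, finish_reason):
    """Restates the streaming hold-back rule from the hunk above."""
    # A trailing U+FFFD usually means the last token ended mid-way through a
    # multi-byte UTF-8 sequence, so its text is held back until it completes.
    holds_back_incomplete_utf8 = bool(decoded_text) and decoded_text.endswith("\ufffd")
    if (holds_back_incomplete_utf8 and not finished_now) or (
        finished_now
        and finish_reason in (FinishReason.LENGTH, FinishReason.STOP_STRING)
    ):
        return ""  # suppress: incomplete UTF-8 mid-stream, or final LENGTH/STOP_STRING chunk
    return token_text
```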
@@ -275,7 +281,9 @@ class LLMEngine:
                 req._stream_last_yielded_length = len(decoded_text)

         # For non-streaming, finish checks happen here.
-        if req._output_queue is None and self._check_request_finished(req, token_id):
+        if req._output_queue is None and self._check_request_finished(
+            req, token_id
+        ):
             req.mark_finished(req.finish_reason)
             # Remove stop string from generated_text if STOP_STRING finish reason
             if req.finish_reason == FinishReason.STOP_STRING:
@@ -283,7 +291,7 @@ class LLMEngine:
                 for stop_str in stop_strings:
                     if req.generated_text.endswith(stop_str):
                         # Remove the stop string from the end
-                        req.generated_text = req.generated_text[:-len(stop_str)]
+                        req.generated_text = req.generated_text[: -len(stop_str)]
                         break

         # Put output in queue if it exists (for async streaming)
@@ -362,6 +370,7 @@ class LLM:
         temperature: float = 1.0,
         top_p: float = 0.8,
         top_k: int = 1,
+        enable_graph: bool = False,
     ):
         """Initialize LLM.
@@ -377,6 +386,7 @@ class LLM:
             temperature: Default sampling temperature.
             top_p: Default top-p sampling parameter.
             top_k: Default top-k sampling parameter.
+            enable_graph: Whether to enable graph compiling.
         """
         config = EngineConfig(
             model_path=model_path,
@@ -390,6 +400,7 @@ class LLM:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            enable_graph=enable_graph,
         )
         self.engine = LLMEngine(config)
         self.config = config
@@ -506,6 +517,7 @@ class AsyncLLMEngine:
         temperature: float = 1.0,
         top_p: float = 0.8,
         top_k: int = 1,
+        enable_graph: bool = False,
     ):
         """Initialize AsyncLLMEngine.
@@ -521,6 +533,7 @@ class AsyncLLMEngine:
             temperature: Default sampling temperature.
             top_p: Default top-p sampling parameter.
             top_k: Default top-k sampling parameter.
+            enable_graph: Whether to enable graph compiling.
         """
         config = EngineConfig(
             model_path=model_path,
@@ -534,6 +547,7 @@ class AsyncLLMEngine:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            enable_graph=enable_graph,
         )
         self.engine = LLMEngine(config)
         self.config = config
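With these changes, graph compiling can be requested directly from the Python API. A hedged usage sketch (the import path is inferred from this file's location, the model path is a placeholder, and any constructor parameters elided by the diff are omitted):

```python
from infinilm.llm.llm import LLM  # module path assumed from the file location

llm = LLM(
    model_path="/path/to/model",  # placeholder
    temperature=1.0,              # defaults as declared in this diff
    top_p=0.8,
    top_k=1,
    enable_graph=True,  # new flag; reaches InferEngine as enable_graph_compiling
)
```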
python/infinilm/server/inference_server.py
@@ -23,7 +23,9 @@ DEFAULT_STREAM_TIMEOUT = 100.0
 DEFAULT_REQUEST_TIMEOUT = 1000.0


-def chunk_json(id_, content=None, role=None, finish_reason=None, model: str = "unknown"):
+def chunk_json(
+    id_, content=None, role=None, finish_reason=None, model: str = "unknown"
+):
     """Generate JSON chunk for streaming response."""
     delta = {}
     if content:
@@ -66,6 +68,7 @@ class InferenceServer:
         top_k: int = 1,
         host: str = "0.0.0.0",
         port: int = 8000,
+        enable_graph: bool = False,
     ):
         """Initialize inference server.
@@ -83,6 +86,7 @@
             top_k: Default top-k sampling parameter.
             host: Server host address.
             port: Server port number.
+            enable_graph: Whether to enable graph compiling.
         """
         self.model_path = model_path
         # vLLM-like served model id: directory name of model_path
@@ -99,6 +103,7 @@
         self.top_k = top_k
         self.host = host
         self.port = port
+        self.enable_graph = enable_graph

         self.engine: AsyncLLMEngine = None
@@ -126,9 +131,11 @@
             temperature=self.temperature,
             top_p=self.top_p,
             top_k=self.top_k,
+            enable_graph=self.enable_graph,
         )
         self.engine.start()
         logger.info(f"Engine initialized with model at {self.model_path}")
+        logger.info(f" enable_graph: {self.enable_graph}")
         yield
         self.engine.stop()
@@ -233,7 +240,6 @@
         if isinstance(stop, str):
             stop = [stop]
-
         return SamplingParams(
             temperature=float(pick("temperature", self.temperature)),
             top_p=float(pick("top_p", self.top_p)),
@@ -291,15 +297,15 @@
                 # Skip EOS token text for OpenAI API compatibility
                 # Check if this token is an EOS token by comparing token_id with eos_token_ids
                 eos_token_ids = self.engine.engine.eos_token_ids
-                is_eos_token = (
-                    eos_token_ids and token_output.token_id in eos_token_ids
-                )
+                is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids
                 if not is_eos_token and token_output.token_text:
                     # Send token
                     chunk = json.dumps(
                         chunk_json(
-                            request_id, content=token_output.token_text, model=self.model_id
+                            request_id,
+                            content=token_output.token_text,
+                            model=self.model_id,
                         ),
                         ensure_ascii=False,
                     )
@@ -379,9 +385,7 @@
             # Skip EOS token text for OpenAI API compatibility
             # Check if this token is an EOS token by comparing token_id with eos_token_ids
             eos_token_ids = self.engine.engine.eos_token_ids
-            is_eos_token = (
-                eos_token_ids and token_output.token_id in eos_token_ids
-            )
+            is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids
             if not is_eos_token:
                 output_text += token_output.token_text
@@ -483,6 +487,11 @@ def parse_args():
     parser.add_argument("--moore", action="store_true", help="Use Moore device")
     parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device")
     parser.add_argument("--cambricon", action="store_true", help="Use Cambricon device")
+    parser.add_argument(
+        "--enable-graph",
+        action="store_true",
+        help="Enable graph compiling",
+    )
     parser.add_argument(
         "--log_level",
         type=str,
@@ -518,6 +527,8 @@ def main():
             "\n"
             "Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ "
             "--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1"
+            "\n"
+            "Optional: --enable-paged-attn --enable-graph"
         )
         sys.exit(1)
@@ -535,6 +546,7 @@ def main():
         top_k=args.top_k,
         host=args.host,
         port=args.port,
+        enable_graph=args.enable_graph,
     )
     server.start()
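End to end, the flag flows from `--enable-graph` on the command line through `InferenceServer` into `AsyncLLMEngine` and finally `InferEngine`. The server can also be constructed programmatically; a hedged sketch using only the parameters visible in this diff (other arguments elided, model path a placeholder):

```python
from infinilm.server.inference_server import InferenceServer  # path assumed

server = InferenceServer(
    model_path="/path/to/model",  # placeholder
    top_k=1,
    host="0.0.0.0",
    port=8000,
    enable_graph=True,  # equivalent to passing --enable-graph on the CLI
)
server.start()
```

On the command line, this corresponds to appending `--enable-graph` to the example invocation shown in the usage string updated by this commit.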