jerrrrry / infinilm · Commits · 3b8e1cb7

Unverified commit 3b8e1cb7, authored Mar 09, 2026 by thatPepe, committed by GitHub on Mar 09, 2026.

Merge pull request #260 from InfiniTensor/issue/259

issue/259 - add attn backend option to inference server

Parents: dfec9d89, 91cd2992
Showing 2 changed files with 22 additions and 1 deletion (+22 −1).
python/infinilm/llm/llm.py (+9 −0)
python/infinilm/server/inference_server.py (+13 −1)
python/infinilm/llm/llm.py (view file @ 3b8e1cb7)
@@ -55,6 +55,7 @@ class EngineConfig:
         top_p: Default top-p sampling parameter.
         top_k: Default top-k sampling parameter.
         enable_graph: Whether to enable graph compiling.
+        attn_backend: Attention backend to use ('default', 'flash-attn').
     """
     model_path: str
@@ -71,6 +72,7 @@ class EngineConfig:
     top_p: float = 0.8
     top_k: int = 1
     enable_graph: bool = False
+    attn_backend: str = "default"

 class LLMEngine:
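For reference, the new field can also be set when building an EngineConfig directly. A minimal sketch, assuming the import path infinilm.llm.llm and that model_path is the only required field (everything else visible in this diff carries a default):

# Sketch only: EngineConfig may have required fields not visible in this diff.
from infinilm.llm.llm import EngineConfig

config = EngineConfig(
    model_path="/path/to/model",   # hypothetical path
    attn_backend="flash-attn",     # new field; defaults to "default"
)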
@@ -88,6 +90,7 @@ class LLMEngine:
             device=self.device,
             distributed_config=DistConfig(config.tensor_parallel_size),
             enable_graph_compiling=config.enable_graph,
+            attention_backend=config.attn_backend,
         )
         # Load model weights
@@ -383,6 +386,7 @@ class LLM:
         top_p: float = 0.8,
         top_k: int = 1,
         enable_graph: bool = False,
+        attn_backend: str = "default",
     ):
         """Initialize LLM.
@@ -401,6 +405,7 @@ class LLM:
             top_p: Default top-p sampling parameter.
             top_k: Default top-k sampling parameter.
             enable_graph: Whether to enable graph compiling.
+            attn_backend: Attention backend to use ('default', 'flash-attn').
         """
         config = EngineConfig(
             model_path=model_path,
@@ -417,6 +422,7 @@ class LLM:
             top_p=top_p,
             top_k=top_k,
             enable_graph=enable_graph,
+            attn_backend=attn_backend,
         )
         self.engine = LLMEngine(config)
         self.config = config
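Taken together, the LLM changes thread the new keyword from the constructor into EngineConfig and on to LLMEngine. A usage sketch, assuming the same import path as above and a hypothetical model directory:

# Sketch only: import path and model directory are assumptions.
from infinilm.llm.llm import LLM

llm = LLM(
    model_path="/data/shared/models/9G7B_MHA/",  # example path borrowed from the server help text
    attn_backend="flash-attn",                   # or "default", the default value
)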
@@ -536,6 +542,7 @@ class AsyncLLMEngine:
         top_p: float = 0.8,
         top_k: int = 1,
         enable_graph: bool = False,
+        attn_backend: str = "default",
     ):
         """Initialize AsyncLLMEngine.
@@ -554,6 +561,7 @@ class AsyncLLMEngine:
             top_p: Default top-p sampling parameter.
             top_k: Default top-k sampling parameter.
             enable_graph: Whether to enable graph compiling.
+            attn_backend: Attention backend to use ('default', 'flash-attn').
         """
         config = EngineConfig(
             model_path=model_path,
@@ -570,6 +578,7 @@ class AsyncLLMEngine:
             top_p=top_p,
             top_k=top_k,
             enable_graph=enable_graph,
+            attn_backend=attn_backend,
         )
         self.engine = LLMEngine(config)
         self.config = config
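AsyncLLMEngine mirrors the synchronous constructor, so the same keyword applies there. A sketch under the same import-path assumption:

# Sketch only: import path and model directory are assumptions.
from infinilm.llm.llm import AsyncLLMEngine

engine = AsyncLLMEngine(
    model_path="/path/to/model",  # hypothetical path
    attn_backend="flash-attn",
)
engine.start()  # the server below calls start() on its AsyncLLMEngine the same way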
python/infinilm/server/inference_server.py (view file @ 3b8e1cb7)
@@ -108,6 +108,7 @@ class InferenceServer:
         host: str = "0.0.0.0",
         port: int = 8000,
         enable_graph: bool = False,
+        attn_backend: str = "default",
     ):
         """Initialize inference server.
@@ -128,6 +129,7 @@ class InferenceServer:
             host: Server host address.
             port: Server port number.
             enable_graph: Whether to enable graph compiling.
+            attn_backend: Attention backend to use ('default', 'flash-attn').
         """
         self.model_path = model_path
         # vLLM-like served model id: directory name of model_path
@@ -147,6 +149,7 @@ class InferenceServer:
         self.host = host
         self.port = port
         self.enable_graph = enable_graph
+        self.attn_backend = attn_backend
         self.engine: AsyncLLMEngine = None
@@ -177,6 +180,7 @@ class InferenceServer:
             top_p=self.top_p,
             top_k=self.top_k,
             enable_graph=self.enable_graph,
+            attn_backend=self.attn_backend,
         )
         self.engine.start()
         logger.info(f"Engine initialized with model at {self.model_path}")
@@ -613,6 +617,13 @@ def parse_args():
         action="store_true",
         help="Enable graph compiling",
     )
+    parser.add_argument(
+        "--attn",
+        type=str,
+        default="default",
+        choices=["default", "flash-attn"],
+        help="Attention backend to use: 'default' or 'flash-attn'",
+    )
     parser.add_argument(
         "--log_level",
         type=str,
@@ -655,7 +666,7 @@ def main():
             "Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ "
             "--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1"
             "\n"
-            "Optional: --enable-paged-attn --enable-graph"
+            "Optional: --enable-paged-attn --enable-graph --attn=default"
         )
         sys.exit(1)
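With the flag registered in parse_args() and surfaced in the help text, backend selection becomes a one-flag change on the command line. An invocation sketch based on the example quoted in the help text (run as a module with -m, which the help text omits; the model path is the one quoted there):

python -m infinilm.server.inference_server --nvidia \
    --model_path=/data/shared/models/9G7B_MHA/ \
    --max_tokens=100 --max_batch_size=32 --tp=1 \
    --attn=flash-attn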
@@ -676,6 +687,7 @@ def main():
         host=args.host,
         port=args.port,
         enable_graph=args.enable_graph,
+        attn_backend=args.attn,
     )
     server.start()
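For programmatic use, the server simply stores the value and forwards it to AsyncLLMEngine. A sketch, assuming InferenceServer is importable from infinilm.server.inference_server and that model_path is its only required constructor argument (other parameters visible in the hunks above all have defaults):

# Sketch only: the constructor may require arguments this diff does not reveal.
from infinilm.server.inference_server import InferenceServer

server = InferenceServer(
    model_path="/path/to/model",  # hypothetical path
    host="0.0.0.0",
    port=8000,
    attn_backend="flash-attn",    # forwarded to AsyncLLMEngine
)
server.start()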