Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
471309e2
Commit
471309e2
authored
Mar 05, 2026
by
wooway777
Browse files
issue/248 - support attn backend in front end and update readme
parent
0ea1cd55
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
39 additions
and
0 deletions
+39
-0
README.md
README.md
+18
-0
examples/bench.py
examples/bench.py
+10
-0
examples/jiuge.py
examples/jiuge.py
+11
-0
No files found.
README.md
View file @
471309e2
...
@@ -160,3 +160,21 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
...
@@ -160,3 +160,21 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench mmlu --subject abstract_algebra --backend cpp --ndev 1 --cache_dir ~/.cache/huggingface/datasets/
python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench mmlu --subject abstract_algebra --backend cpp --ndev 1 --cache_dir ~/.cache/huggingface/datasets/
```
```
> 注意:`--cache_dir` 应指向包含 `ceval___ceval-exam` 和 `cais___mmlu` 等数据集子目录的父目录,而不是直接指向这些子目录
> 注意:`--cache_dir` 应指向包含 `ceval___ceval-exam` 和 `cais___mmlu` 等数据集子目录的父目录,而不是直接指向这些子目录
- 试验中功能
- Warm Up
```bash
python examples/bench.py --nvidia --model=<model-path> --warmup
```
- Paged Attention
```bash
python examples/bench.py --nvidia --model=<model-path> --enable-paged-attn
```
- CUDA Graph
```bash
python examples/bench.py --nvidia --model=<model-path> --enable-paged-attn --enable-graph
```
- 选择attention后端 (使用flash attention后端需要先在InfiniCore完成相关配置和编译)
```bash
python examples/bench.py --nvidia --model=<model-path> --enable-paged-attn [--attn=flash-attn | --attn=default]
```
examples/bench.py
View file @
471309e2
...
@@ -252,6 +252,13 @@ def get_args():
...
@@ -252,6 +252,13 @@ def get_args():
action
=
"store_true"
,
action
=
"store_true"
,
help
=
"Perform a warmup run before benchmarking/inference."
,
help
=
"Perform a warmup run before benchmarking/inference."
,
)
)
parser
.
add_argument
(
"--attn"
,
type
=
str
,
default
=
"flash-attn"
,
choices
=
[
"default"
,
"flash-attn"
],
help
=
"attention backend to use: 'default' or 'flash-attn'"
,
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -278,6 +285,7 @@ class TestModel:
...
@@ -278,6 +285,7 @@ class TestModel:
skip_load
=
False
,
skip_load
=
False
,
cache_config
=
None
,
cache_config
=
None
,
enable_graph
=
False
,
enable_graph
=
False
,
attn_backend
=
"flash-attn"
,
)
->
None
:
)
->
None
:
model_path
=
os
.
path
.
expanduser
(
model_path
)
model_path
=
os
.
path
.
expanduser
(
model_path
)
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #
...
@@ -289,6 +297,7 @@ class TestModel:
...
@@ -289,6 +297,7 @@ class TestModel:
distributed_config
=
DistConfig
(
tp
),
distributed_config
=
DistConfig
(
tp
),
cache_config
=
cache_config
,
cache_config
=
cache_config
,
enable_graph_compiling
=
enable_graph
,
enable_graph_compiling
=
enable_graph
,
attention_backend
=
attn_backend
,
)
)
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #
...
@@ -461,6 +470,7 @@ if __name__ == "__main__":
...
@@ -461,6 +470,7 @@ if __name__ == "__main__":
skip_load
=
skip_load
,
skip_load
=
skip_load
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
enable_graph
=
enable_graph
,
enable_graph
=
enable_graph
,
attn_backend
=
args
.
attn
,
)
)
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #
...
...
examples/jiuge.py
View file @
471309e2
...
@@ -142,6 +142,14 @@ def get_args():
...
@@ -142,6 +142,14 @@ def get_args():
help
=
"sampling temperature"
,
help
=
"sampling temperature"
,
)
)
parser
.
add_argument
(
"--attn"
,
type
=
str
,
default
=
"flash-attn"
,
choices
=
[
"default"
,
"flash-attn"
],
help
=
"attention backend to use: 'default' or 'flash-attn'"
,
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -156,6 +164,7 @@ def test(
...
@@ -156,6 +164,7 @@ def test(
top_k
=
1
,
top_k
=
1
,
top_p
=
1.0
,
top_p
=
1.0
,
temperature
=
1.0
,
temperature
=
1.0
,
attn_backend
=
"flash-attn"
,
):
):
model_path
=
os
.
path
.
expanduser
(
model_path
)
model_path
=
os
.
path
.
expanduser
(
model_path
)
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #
...
@@ -166,6 +175,7 @@ def test(
...
@@ -166,6 +175,7 @@ def test(
device
=
infini_device
,
device
=
infini_device
,
distributed_config
=
DistConfig
(
tp
),
distributed_config
=
DistConfig
(
tp
),
enable_graph_compiling
=
enable_graph
,
enable_graph_compiling
=
enable_graph
,
attention_backend
=
attn_backend
,
)
)
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #
# Load Weights
# Load Weights
...
@@ -333,4 +343,5 @@ if __name__ == "__main__":
...
@@ -333,4 +343,5 @@ if __name__ == "__main__":
top_k
=
args
.
top_k
,
top_k
=
args
.
top_k
,
top_p
=
args
.
top_p
,
top_p
=
args
.
top_p
,
temperature
=
args
.
temperature
,
temperature
=
args
.
temperature
,
attn_backend
=
args
.
attn
,
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment