Commit 471309e2 authored by wooway777

issue/248 - support attn backend in front end and update readme

parent 0ea1cd55
@@ -160,3 +160,21 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench mmlu --subject abstract_algebra --backend cpp --ndev 1 --cache_dir ~/.cache/huggingface/datasets/
```
> Note: `--cache_dir` should point to the parent directory that contains dataset subdirectories such as `ceval___ceval-exam` and `cais___mmlu`, not directly to those subdirectories.
- Experimental features
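For example, with the datasets cached under the default Hugging Face location used in the command above, the expected layout looks like this (an illustrative sketch; the exact set of subdirectories depends on which benchmarks have been downloaded):
```bash
# Pass the parent directory to --cache_dir ...
$ ls ~/.cache/huggingface/datasets/
cais___mmlu  ceval___ceval-exam
# ... not one of the dataset subdirectories themselves.
```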
- Warm Up
```bash
python examples/bench.py --nvidia --model=<model-path> --warmup
```
- Paged Attention
```bash
python examples/bench.py --nvidia --model=<model-path> --enable-paged-attn
```
- CUDA Graph
```bash
python examples/bench.py --nvidia --model=<model-path> --enable-paged-attn --enable-graph
```
- Select the attention backend (using the flash attention backend requires completing the relevant configuration and build in InfiniCore first; see the note after the example below)
```bash
python examples/bench.py --nvidia --model=<model-path> --enable-paged-attn [--attn=flash-attn | --attn=default]
```
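Note that both example scripts default to `--attn=flash-attn` (see the argument definitions in the diff below), so on a setup where InfiniCore has not been configured for flash attention, select the built-in backend explicitly:
```bash
python examples/bench.py --nvidia --model=<model-path> --enable-paged-attn --attn=default
```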
@@ -252,6 +252,13 @@ def get_args():
action="store_true",
help="Perform a warmup run before benchmarking/inference.",
)
parser.add_argument(
"--attn",
type=str,
default="flash-attn",
choices=["default", "flash-attn"],
help="attention backend to use: 'default' or 'flash-attn'",
)
return parser.parse_args()
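For reference, `choices` makes argparse enforce the two supported values. A minimal standalone sketch of the flag's behavior (it mirrors the argument added above; not part of the diff):
```python
import argparse

# Standalone mirror of the --attn flag added in this commit
parser = argparse.ArgumentParser()
parser.add_argument(
    "--attn",
    type=str,
    default="flash-attn",
    choices=["default", "flash-attn"],
    help="attention backend to use: 'default' or 'flash-attn'",
)

print(parser.parse_args([]).attn)                      # "flash-attn" (the default)
print(parser.parse_args(["--attn", "default"]).attn)   # "default"
# Any other value, e.g. --attn=paged, makes argparse exit with a usage error.
```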
@@ -278,6 +285,7 @@ class TestModel:
skip_load=False,
cache_config=None,
enable_graph=False,
attn_backend="flash-attn",
) -> None:
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- #
@@ -289,6 +297,7 @@ class TestModel:
distributed_config=DistConfig(tp),
cache_config=cache_config,
enable_graph_compiling=enable_graph,
attention_backend=attn_backend,
)
# ---------------------------------------------------------------------------- #
@@ -461,6 +470,7 @@ if __name__ == "__main__":
skip_load=skip_load,
cache_config=cache_config,
enable_graph=enable_graph,
attn_backend=args.attn,
)
# ---------------------------------------------------------------------------- #
@@ -142,6 +142,14 @@ def get_args():
help="sampling temperature",
)
parser.add_argument(
"--attn",
type=str,
default="flash-attn",
choices=["default", "flash-attn"],
help="attention backend to use: 'default' or 'flash-attn'",
)
return parser.parse_args()
@@ -156,6 +164,7 @@ def test(
top_k=1,
top_p=1.0,
temperature=1.0,
attn_backend="flash-attn",
):
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- #
@@ -166,6 +175,7 @@ def test(
device=infini_device,
distributed_config=DistConfig(tp),
enable_graph_compiling=enable_graph,
attention_backend=attn_backend,
)
# ---------------------------------------------------------------------------- #
# Load Weights
@@ -333,4 +343,5 @@ if __name__ == "__main__":
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
attn_backend=args.attn,
)