"docs/vscode:/vscode.git/clone" did not exist on "acb955e15ff84c07d77ba27c839dcdfdc82e79e1"
Unverified commit e8100774, authored by Dom Brown, committed by GitHub
Browse files

Allow use of TRTLLM_MHA backend for hybrid attention on Blackwell (#11138)

parent 963175d5
...@@ -178,7 +178,8 @@ def attn_backend_wrapper(runner, full_attn_backend): ...@@ -178,7 +178,8 @@ def attn_backend_wrapper(runner, full_attn_backend):
if is_blackwell(): if is_blackwell():
assert ( assert (
runner.server_args.attention_backend == "triton" runner.server_args.attention_backend == "triton"
), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend." or runner.server_args.attention_backend == "trtllm_mha"
), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
if is_npu(): if is_npu():
assert ( assert (
runner.server_args.attention_backend == "ascend" runner.server_args.attention_backend == "ascend"
......
...@@ -1620,7 +1620,7 @@ class ModelRunner: ...@@ -1620,7 +1620,7 @@ class ModelRunner:
) )
elif self.is_hybrid_gdn: elif self.is_hybrid_gdn:
self.token_to_kv_pool = HybridLinearKVPool( self.token_to_kv_pool = HybridLinearKVPool(
page_size=self.page_size if _is_npu else 1, page_size=self.page_size,
size=self.max_total_num_tokens, size=self.max_total_num_tokens,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads( head_num=self.model_config.get_num_kv_heads(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment