"tests/pytorch/sparse/test_elementwise_op_sp.py" did not exist on "354a211038b669b14e7fa0d1519577996ccaf300"
Unverified Commit 297d3745 authored by Yi Zhang's avatar Yi Zhang Committed by GitHub
Browse files

support qwen3_next blackwell (#10403)

parent 31e9d3a5
......@@ -80,7 +80,13 @@ class TritonAttnBackend(AttentionBackend):
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
if model_runner.is_hybrid_gdn:
# For hybrid linear models, layer_id = 0 may not be full attention
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
else:
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[
-1
]
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.device_core_count = get_device_core_count(model_runner.gpu_id)
......
......@@ -728,6 +728,9 @@ class HybridLinearKVPool(KVCache):
layer_id_override=layer_id,
)
def get_v_head_dim(self):
    """Return the value head dimension of the backing full-attention KV pool.

    Hybrid linear models may not have full attention at layer 0 of this
    wrapper, so callers query the wrapped full KV pool's value buffer and
    read its last dimension instead of indexing an arbitrary layer here.
    """
    value_buffer = self.full_kv_pool.get_value_buffer(0)
    return value_buffer.shape[-1]
class SWAKVPool(KVCache):
"""KV cache with separate pools for full and SWA attention layers."""
......
......@@ -127,6 +127,7 @@ from sglang.srt.utils import (
get_bool_env_var,
get_cpu_ids_by_node,
init_custom_process_group,
is_blackwell,
is_fa3_default_architecture,
is_flashinfer_available,
is_hip,
......@@ -1832,6 +1833,10 @@ class ModelRunner:
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
full_attn_backend = AscendAttnBackend(self)
elif is_blackwell():
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
full_attn_backend = TritonAttnBackend(self)
else:
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
......
......@@ -48,6 +48,7 @@ from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
get_bool_env_var,
is_blackwell,
is_cuda,
next_power_of_2,
)
......@@ -214,7 +215,11 @@ class EAGLEWorker(TpModelWorker):
"triton": self._create_triton_decode_backend,
"aiter": self._create_aiter_decode_backend,
"fa3": self._create_fa3_decode_backend,
"hybrid_linear_attn": self._create_fa3_decode_backend,
"hybrid_linear_attn": (
self._create_fa3_decode_backend
if not is_blackwell()
else self._create_triton_decode_backend
),
"flashmla": self._create_flashmla_decode_backend,
"trtllm_mha": self._create_trtllm_mha_decode_backend,
"trtllm_mla": self._create_trtllm_mla_decode_backend,
......@@ -232,7 +237,11 @@ class EAGLEWorker(TpModelWorker):
"triton": self._create_triton_prefill_backend,
"aiter": self._create_aiter_prefill_backend,
"fa3": self._create_fa3_prefill_backend,
"hybrid_linear_attn": self._create_fa3_prefill_backend,
"hybrid_linear_attn": (
self._create_fa3_prefill_backend
if not is_blackwell()
else self._create_triton_prefill_backend
),
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment