sglang, commit 297d3745 (unverified)

support qwen3_next blackwell (#10403)

Authored Sep 13, 2025 by Yi Zhang; committed via GitHub on Sep 13, 2025.
Parent: 31e9d3a5
Showing 4 changed files with 26 additions and 3 deletions (+26, -3):
python/sglang/srt/layers/attention/triton_backend.py    +7  -1
python/sglang/srt/mem_cache/memory_pool.py              +3  -0
python/sglang/srt/model_executor/model_runner.py        +5  -0
python/sglang/srt/speculative/eagle_worker.py           +11 -2
python/sglang/srt/layers/attention/triton_backend.py

@@ -80,7 +80,13 @@ class TritonAttnBackend(AttentionBackend):
         self.num_kv_head = model_runner.model_config.get_num_kv_heads(
             get_attention_tp_size()
         )
-        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+        if model_runner.is_hybrid_gdn:
+            # For hybrid linear models, layer_id = 0 may not be full attention
+            self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
+        else:
+            self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(
+                0
+            ).shape[-1]
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.device_core_count = get_device_core_count(model_runner.gpu_id)
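The hunk above changes where the backend reads the value head dimension from: on hybrid gated-DeltaNet (GDN) models such as Qwen3-Next, layer 0 may be a linear-attention layer with no KV buffer, so indexing the value buffer of layer 0 is no longer safe. A minimal, hedged sketch of the same branch, using invented stand-ins (`FakeRunner`, `FakePool`) rather than sglang's real classes:

```python
# Hedged sketch of the v_head_dim branch; FakeRunner/FakePool are invented
# stand-ins, only the attribute names mirror the patch above.
import torch


class FakePool:
    def __init__(self, head_dim: int):
        self._v = torch.zeros(16, 8, head_dim)  # (tokens, kv_heads, head_dim)

    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        return self._v

    def get_v_head_dim(self) -> int:
        return self._v.shape[-1]


class FakeRunner:
    def __init__(self, hybrid: bool):
        self.is_hybrid_gdn = hybrid
        self.token_to_kv_pool = FakePool(head_dim=128)


def resolve_v_head_dim(runner: FakeRunner) -> int:
    if runner.is_hybrid_gdn:
        # Hybrid linear models: layer 0 may be a GDN layer without a KV
        # buffer, so ask the pool for the full-attention head dim directly.
        return runner.token_to_kv_pool.get_v_head_dim()
    # Dense models: every layer has a KV buffer, so layer 0 is representative.
    return runner.token_to_kv_pool.get_value_buffer(0).shape[-1]


assert resolve_v_head_dim(FakeRunner(hybrid=True)) == 128
```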
python/sglang/srt/mem_cache/memory_pool.py

@@ -728,6 +728,9 @@ class HybridLinearKVPool(KVCache):
             layer_id_override=layer_id,
         )

+    def get_v_head_dim(self):
+        return self.full_kv_pool.get_value_buffer(0).shape[-1]
+

 class SWAKVPool(KVCache):
     """KV cache with separate pools for full and SWA attention layers."""
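The new accessor lives on `HybridLinearKVPool` because, inside the hybrid pool, index 0 refers to the first full-attention layer's buffer, which exists by construction, unlike global layer id 0. A hedged sketch of that delegation, with invented class names standing in for the real pool hierarchy:

```python
# Hedged sketch: the hybrid pool only materializes KV buffers for its
# full-attention layers, so index 0 of its internal full_kv_pool is always
# a real KV buffer, even when global layer 0 is a linear (GDN) layer.
# Class names here are invented, not sglang's real API.
import torch


class FullAttnSubPool:
    def __init__(self, num_full_layers: int, head_dim: int):
        # One value buffer per *full-attention* layer, not per model layer.
        self._buffers = [torch.zeros(16, 8, head_dim) for _ in range(num_full_layers)]

    def get_value_buffer(self, idx: int) -> torch.Tensor:
        return self._buffers[idx]


class HybridLinearPoolSketch:
    def __init__(self, full_attention_layer_ids: list[int], head_dim: int):
        self.full_attention_layer_ids = full_attention_layer_ids
        self.full_kv_pool = FullAttnSubPool(len(full_attention_layer_ids), head_dim)

    def get_v_head_dim(self) -> int:
        # Index 0 means "first full-attention layer", not global layer 0.
        return self.full_kv_pool.get_value_buffer(0).shape[-1]


# Illustrative hybrid layout: full attention only on a subset of layers.
pool = HybridLinearPoolSketch(full_attention_layer_ids=[3, 7, 11], head_dim=128)
assert pool.get_v_head_dim() == 128
```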
python/sglang/srt/model_executor/model_runner.py

@@ -127,6 +127,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
+    is_blackwell,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
@@ -1832,6 +1833,10 @@ class ModelRunner:
             from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

             full_attn_backend = AscendAttnBackend(self)
+        elif is_blackwell():
+            from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+            full_attn_backend = TritonAttnBackend(self)
         else:
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
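This branch routes the full-attention sub-layers of hybrid models to the Triton backend on Blackwell GPUs instead of falling through to the FlashAttention path, plausibly because FlashAttention-3 kernels target Hopper and are not available on Blackwell. As a hedged sketch only (sglang's real `is_blackwell()` helper may be implemented differently), a capability-based check could look like this, assuming Blackwell parts report CUDA compute capability major version 10:

```python
# Hedged sketch of a Blackwell check; sglang's actual is_blackwell() may
# differ. Assumption: Blackwell GPUs report compute capability 10.x.
import torch


def is_blackwell_sketch() -> bool:
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 10
```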
python/sglang/srt/speculative/eagle_worker.py

@@ -48,6 +48,7 @@ from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
     get_bool_env_var,
+    is_blackwell,
     is_cuda,
     next_power_of_2,
 )
@@ -214,7 +215,11 @@ class EAGLEWorker(TpModelWorker):
             "triton": self._create_triton_decode_backend,
             "aiter": self._create_aiter_decode_backend,
             "fa3": self._create_fa3_decode_backend,
-            "hybrid_linear_attn": self._create_fa3_decode_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_decode_backend
+                if not is_blackwell()
+                else self._create_triton_decode_backend
+            ),
             "flashmla": self._create_flashmla_decode_backend,
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
@@ -232,7 +237,11 @@ class EAGLEWorker(TpModelWorker):
             "triton": self._create_triton_prefill_backend,
             "aiter": self._create_aiter_prefill_backend,
             "fa3": self._create_fa3_prefill_backend,
-            "hybrid_linear_attn": self._create_fa3_prefill_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_prefill_backend
+                if not is_blackwell()
+                else self._create_triton_prefill_backend
+            ),
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
         }
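Both hunks apply the same pattern: the EAGLE worker builds a dict mapping backend names to zero-argument factory methods, and the `hybrid_linear_attn` entry is now chosen at dict-construction time based on the GPU generation. A self-contained, hedged sketch of that dispatch pattern (names here are illustrative, not sglang's real API):

```python
# Hedged sketch of the factory-dict dispatch used above: one entry is
# selected conditionally when the dict is built, so lookup stays uniform.
from typing import Callable, Dict


def make_fa3_backend() -> str:
    return "fa3-backend"


def make_triton_backend() -> str:
    return "triton-backend"


def build_decode_factories(on_blackwell: bool) -> Dict[str, Callable[[], str]]:
    return {
        "fa3": make_fa3_backend,
        "triton": make_triton_backend,
        # Conditional entry: in this sketch, FA3 kernels are unavailable on
        # Blackwell, so hybrid models fall back to the Triton factory.
        "hybrid_linear_attn": (
            make_fa3_backend if not on_blackwell else make_triton_backend
        ),
    }


factories = build_decode_factories(on_blackwell=True)
backend = factories["hybrid_linear_attn"]()  # -> "triton-backend"
```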