sglang · Commit 646cef2e (unverified)
Authored by Yi Zhang, Jul 04, 2025; committed by GitHub, Jul 03, 2025
Parent: 1dce6c48

support qwen3 dense model dp attention (#7681)
Showing 2 changed files with 49 additions and 17 deletions (+49, -17):

- python/sglang/srt/models/qwen2.py (+7, -1)
- python/sglang/srt/models/qwen3.py (+42, -16)
python/sglang/srt/models/qwen2.py
@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
```python
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
```
@@ -264,6 +265,7 @@ class Qwen2Model(nn.Module):
```python
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                enable_tp=not global_server_args_dict["enable_dp_attention"],
                prefix=add_prefix("embed_tokens", prefix),
            )
        else:
```
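The `enable_tp` flag above controls whether the vocabulary embedding is tensor-parallel sharded: with DP attention enabled, the table stays replicated on each rank. A minimal sketch of that decision, using a plain dict in place of sglang's `global_server_args_dict` and made-up sizes (this helper is illustrative, not an sglang API):

```python
# Illustrative sketch only: how the embedding layout follows enable_dp_attention.
def embedding_rows_per_rank(vocab_size: int, tp_size: int, server_args: dict) -> int:
    enable_tp = not server_args["enable_dp_attention"]
    if enable_tp:
        # Classic tensor parallelism: the vocab dimension is split across ranks.
        return vocab_size // tp_size
    # With DP attention, each rank runs its own batch slice and keeps the
    # full (replicated) embedding table.
    return vocab_size

print(embedding_rows_per_rank(151_936, 4, {"enable_dp_attention": False}))  # 37984
print(embedding_rows_per_rank(151_936, 4, {"enable_dp_attention": True}))   # 151936
```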
@@ -331,6 +333,10 @@ class Qwen2Model(nn.Module):
```python
                        "residual": residual,
                    }
                )
        else:
            if hidden_states.shape[0] != 0:
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
```
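A small standalone illustration of why the `hidden_states.shape[0] != 0` guard is useful once DP attention is enabled: a rank's local batch slice can be empty, and the guard skips normalization on zero-token inputs. This is a sketch assuming a recent PyTorch with `torch.nn.RMSNorm`, not code taken from sglang (sglang's own RMSNorm has a fused residual path):

```python
import torch

norm = torch.nn.RMSNorm(8)            # stand-in for the model's final norm
hidden_states = torch.empty(0, 8)     # an empty local batch on this DP rank
residual = None

# Mirrors the guarded pattern above: only normalize when there are tokens.
if hidden_states.shape[0] != 0:
    if residual is None:
        hidden_states = norm(hidden_states)
    else:
        hidden_states = norm(hidden_states + residual)

print(hidden_states.shape)  # torch.Size([0, 8]); the norm was skipped
```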
python/sglang/srt/models/qwen3.py
@@ -14,6 +14,8 @@ from sglang.srt.distributed import (
```python
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
)
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
```
@@ -54,18 +56,21 @@ class Qwen3Attention(nn.Module):
```diff
         self.hidden_size = hidden_size
         self.tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % self.tp_size == 0
-        self.num_heads = self.total_num_heads // self.tp_size
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= self.tp_size:
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % self.tp_size == 0
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = head_dim or hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
```
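The hunk above is the core of the change: head counts are now derived from the attention-TP group size rather than the global TP size, so each rank sizes its Q/KV projections correctly under DP attention. A standalone sketch of that arithmetic in plain Python, with illustrative head counts (in sglang, `attn_tp_size` would come from `get_attention_tp_size()`):

```python
# Sketch of the head-partitioning logic above, outside sglang.
def split_heads(total_num_heads: int, total_num_kv_heads: int, attn_tp_size: int):
    assert total_num_heads % attn_tp_size == 0
    num_heads = total_num_heads // attn_tp_size
    if total_num_kv_heads >= attn_tp_size:
        # Enough KV heads: partition them across the attention-TP group.
        assert total_num_kv_heads % attn_tp_size == 0
    else:
        # Fewer KV heads than ranks: replicate them instead.
        assert attn_tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // attn_tp_size)
    return num_heads, num_kv_heads

# e.g. 32 query heads and 8 KV heads on an attention-TP group of 4:
print(split_heads(32, 8, 4))   # (8, 2)  -> KV heads are partitioned
# and the replication case, 8 KV heads on a group of 16:
print(split_heads(32, 8, 16))  # (2, 1)  -> each rank keeps one replicated KV head
```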
@@ -84,6 +89,8 @@ class Qwen3Attention(nn.Module):
```python
            self.total_num_kv_heads,
            bias=attention_bias,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
```
@@ -91,6 +98,9 @@ class Qwen3Attention(nn.Module):
```python
            hidden_size,
            bias=attention_bias,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            reduce_results=False,
            prefix=add_prefix("o_proj", prefix),
        )
```
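`reduce_results=False` on `o_proj` means each rank keeps a partial output and the reduction is deferred to later communication (handled in this layer by the `LayerCommunicator`). A pure-PyTorch sketch of the underlying row-parallel idea, with made-up sizes and no real distributed setup:

```python
import torch

tp_size, tokens, hidden = 2, 4, 8
full_w = torch.randn(hidden, hidden)
x = torch.randn(tokens, hidden)

# Row parallelism: each rank owns a slice of the weight's input dimension
# and the matching slice of the activation, producing a partial output.
partials = []
shard = hidden // tp_size
for rank in range(tp_size):
    w_shard = full_w[rank * shard:(rank + 1) * shard, :]
    x_shard = x[:, rank * shard:(rank + 1) * shard]
    partials.append(x_shard @ w_shard)   # partial result, not yet reduced

out = sum(partials)  # the deferred "all-reduce" step
assert torch.allclose(out, x @ full_w, atol=1e-5)
```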
@@ -176,6 +186,18 @@ class Qwen3DecoderLayer(nn.Module):
```python
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=False,
            is_previous_layer_sparse=False,
        )
        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def forward(
        self,
        positions: torch.Tensor,
```
@@ -184,11 +206,10 @@ class Qwen3DecoderLayer(nn.Module):
```diff
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+        if hidden_states.shape[0] != 0:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
```
@@ -196,8 +217,13 @@ class Qwen3DecoderLayer(nn.Module):
```diff
             )

         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
         hidden_states = self.mlp(hidden_states)
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
         return hidden_states, residual
```
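Putting the two forward hunks together: residual handling and cross-rank communication are now routed through the communicator hooks instead of explicit layernorm calls. A condensed sketch of that call order with simplified signatures and identity stand-ins for the hooks (the real `prepare_attn`/`prepare_mlp`/`postprocess_layer` also take a `forward_batch` and perform the actual gather/scatter):

```python
import torch

# Sketch only: identity stand-ins, not sglang's LayerCommunicator.
class NoOpCommunicator:
    def prepare_attn(self, h, res):
        return h, res
    def prepare_mlp(self, h, res):
        return h, res
    def postprocess_layer(self, h, res):
        return h, res

def decoder_layer_forward(hidden_states, residual, self_attn, mlp, comm):
    hidden_states, residual = comm.prepare_attn(hidden_states, residual)
    if hidden_states.shape[0] != 0:        # skip attention for empty DP slices
        hidden_states = self_attn(hidden_states)
    hidden_states, residual = comm.prepare_mlp(hidden_states, residual)
    hidden_states = mlp(hidden_states)
    hidden_states, residual = comm.postprocess_layer(hidden_states, residual)
    return hidden_states, residual

# Tiny usage example with identity attention/MLP.
h = torch.randn(3, 8)
out, res = decoder_layer_forward(h, None, lambda x: x, lambda x: x, NoOpCommunicator())
print(out.shape)  # torch.Size([3, 8])
```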