change / sglang / Commits / 3196999f

Unverified commit 3196999f, authored Mar 18, 2025 by Cheng Wan, committed by GitHub on Mar 18, 2025

Reduce computation and communication in DP attention (#4521)
parent 9e0186f3

Showing 5 changed files with 65 additions and 75 deletions (+65, -75)
python/sglang/srt/distributed/parallel_state.py   +6   -2
python/sglang/srt/layers/dp_attention.py          +21  -21
python/sglang/srt/layers/logits_processor.py      +2   -2
python/sglang/srt/models/deepseek_v2.py           +35  -45
test/srt/test_dp_attention.py                     +1   -5
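For orientation: with DP attention, each data-parallel (DP) rank keeps only its own requests' tokens through the attention layers, while the MLP/MoE part is tensor-parallel over the whole group and therefore needs every rank's tokens at once. Below is a minimal single-process sketch of that gather/scatter bookkeeping (hypothetical helper names, plain Python lists in place of CUDA tensors and collectives; not the sglang implementation):

    def toy_dp_gather(local_batches):
        """Concatenate every DP rank's local tokens into one global batch."""
        global_batch, offsets = [], []
        for tokens in local_batches:
            offsets.append(len(global_batch))   # start of this rank's slice
            global_batch.extend(tokens)
        return global_batch, offsets

    def toy_dp_scatter(global_batch, offsets, rank, local_len):
        """Return the slice of the global batch that belongs to `rank`."""
        start = offsets[rank]
        return global_batch[start : start + local_len]

    if __name__ == "__main__":
        local = [["a0", "a1", "a2"], ["b0"]]            # two DP ranks, uneven load
        global_batch, offsets = toy_dp_gather(local)    # what the TP MLP would see
        assert toy_dp_scatter(global_batch, offsets, 0, 3) == ["a0", "a1", "a2"]
        assert toy_dp_scatter(global_batch, offsets, 1, 1) == ["b0"]

The commit below reduces how much of the model runs on the global batch rather than on each rank's local slice, and trims the communication needed around those transitions.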
python/sglang/srt/distributed/parallel_state.py  (+6, -2)

@@ -189,6 +189,9 @@ class GroupCoordinator:
     device_group: ProcessGroup  # group for device communication
     use_pynccl: bool  # a hint of whether to use PyNccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
+    use_message_queue_broadcaster: (
+        bool  # a hint of whether to use message queue broadcaster
+    )
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator

@@ -241,6 +244,7 @@ class GroupCoordinator:
         self.use_custom_allreduce = use_custom_allreduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
+        self.use_message_queue_broadcaster = use_message_queue_broadcaster

         # lazy import to avoid documentation build error
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (

@@ -269,7 +273,7 @@ class GroupCoordinator:
             HpuCommunicator,
         )
-        self.hpu_communicator: Optional[HpuCommunicator]
+        self.hpu_communicator: Optional[HpuCommunicator] = None
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)

@@ -277,7 +281,7 @@ class GroupCoordinator:
             XpuCommunicator,
         )
-        self.xpu_communicator: Optional[XpuCommunicator]
+        self.xpu_communicator: Optional[XpuCommunicator] = None
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)
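The `= None` defaults above matter because GroupCoordinator only constructs a communicator when the world size is larger than one, while other code paths test the attribute unconditionally. A minimal sketch of that pattern (hypothetical class names, no torch.distributed; a simplification, not the GroupCoordinator API):

    from typing import Optional

    class ToyCommunicator:
        def all_reduce(self, xs):
            return xs  # stand-in for a real collective

    class ToyCoordinator:
        def __init__(self, world_size: int, use_comm: bool):
            # Default to None so single-rank runs never read an unset attribute.
            self.comm: Optional[ToyCommunicator] = None
            if use_comm and world_size > 1:
                self.comm = ToyCommunicator()

        def all_reduce(self, xs):
            if self.comm is None:      # single rank: nothing to reduce
                return xs
            return self.comm.all_reduce(xs)

    print(ToyCoordinator(world_size=1, use_comm=True).all_reduce([1, 2, 3]))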
python/sglang/srt/layers/dp_attention.py  (+21, -21)

@@ -53,10 +53,8 @@ def initialize_dp_attention(
     )

     if enable_dp_attention:
-        local_rank = tp_rank % (tp_size // dp_size)
         _DP_SIZE = dp_size
     else:
-        local_rank = tp_rank
         _DP_SIZE = 1

     tp_group = get_tp_group()

@@ -65,7 +63,7 @@ def initialize_dp_attention(
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-        local_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,

@@ -180,20 +178,19 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](
         dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE
     )


-def dp_gather(
+def _dp_gather(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
-    layer_id: Union[str, int],
+    is_partial: bool,
 ):
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

     global_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
-    if local_tokens.shape[0] > 0 and (
-        layer_id != "embedding" or get_attention_tp_rank() == 0
-    ):
+    if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
         assert (
             global_tokens.untyped_storage().data_ptr()
             != local_tokens.untyped_storage().data_ptr()

@@ -216,6 +213,22 @@ def dp_gather(
         global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)


+def dp_gather_partial(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True)
+
+
+def dp_gather_replicate(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=False)
+
+
 def dp_scatter(
     local_tokens: torch.Tensor,  # output
     global_tokens: torch.Tensor,  # input

@@ -236,16 +249,3 @@ def dp_scatter(
         memcpy_triton(
             local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
         )
-
-
-def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
-    def do_logits_dp_scatter(logits: torch.Tensor):
-        local_logits = torch.empty(
-            (forward_batch.input_ids.shape[0], *logits.shape[1:]),
-            dtype=logits.dtype,
-            device=logits.device,
-        )
-        dp_scatter(local_logits, logits, forward_batch)
-        return local_logits
-
-    return do_logits_dp_scatter
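The is_partial flag decides who contributes to the all-reduce that finishes the gather: with dp_gather_partial every attention-TP rank holds a partial sum and all of them must be added, while with dp_gather_replicate every rank already holds an identical full copy, so only rank 0 may contribute or the result would be counted attention-TP-size times. A single-process sketch of that arithmetic (hypothetical names, plain Python in place of the real all-reduce):

    def allreduce_sum(contributions):
        """Element-wise sum over ranks, i.e. what the all-reduce computes."""
        total = [0.0] * len(contributions[0])
        for contrib in contributions:
            total = [t + c for t, c in zip(total, contrib)]
        return total

    attn_tp_size = 2
    partial_outputs = [[1.0, 2.0], [3.0, 4.0]]   # per-rank partial sums
    replicated_output = [4.0, 6.0]               # identical copy on every rank

    # dp_gather_partial: every rank contributes its partial sum.
    assert allreduce_sum(partial_outputs) == [4.0, 6.0]

    # dp_gather_replicate: only rank 0 contributes, the rest send zeros.
    contribs = [replicated_output] + [[0.0, 0.0]] * (attn_tp_size - 1)
    assert allreduce_sum(contribs) == [4.0, 6.0]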
python/sglang/srt/layers/logits_processor.py  (+2, -2)

@@ -28,7 +28,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
     get_attention_dp_size,

@@ -428,7 +428,7 @@ class LogitsProcessor(nn.Module):
                 logits_metadata.gathered_buffer,
                 hidden_states.clone(),
             )
-            dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
+            dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

         if hasattr(lm_head, "weight"):
             logits = torch.matmul(
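After the replicate-gather every rank holds the same [num_tokens, hidden] matrix and multiplies it with its own vocabulary shard of lm_head. A small standalone sketch of that shape logic (assumed toy sizes, plain torch, no distributed initialization):

    import torch

    num_tokens, hidden_size, vocab_shard = 4, 8, 16
    hidden_states = torch.randn(num_tokens, hidden_size)   # gathered, replicated
    lm_head_shard = torch.randn(vocab_shard, hidden_size)  # this rank's rows of lm_head

    logits_shard = torch.matmul(hidden_states, lm_head_shard.t())
    print(logits_shard.shape)   # torch.Size([4, 16]); vocab shards are combined later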
python/sglang/srt/models/deepseek_v2.py  (+35, -45)

@@ -33,7 +33,7 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,

@@ -939,47 +939,58 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        # Scatter
-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)

         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
         )

         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
             if self.dp_size != 1:
                 if get_attention_tp_rank() == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
                     forward_batch.gathered_buffer,
                     hidden_states,
                 )
-                dp_gather(
-                    hidden_states, local_hidden_states, forward_batch, self.layer_id
-                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
                 dp_scatter(residual, hidden_states, forward_batch)
                 hidden_states = self.post_attention_layernorm(hidden_states)
             else:
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                 hidden_states, residual = self.post_attention_layernorm(
                     hidden_states, residual
                 )
         else:
             hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

         # Fully Connected
         hidden_states = self.mlp(hidden_states)

+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return hidden_states, residual

@@ -1025,18 +1036,6 @@ class DeepseekV2Model(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-
-        # Gather
-        if self.dp_size != 1:
-            input_ids, local_input_ids = (
-                torch.empty(
-                    (forward_batch.gathered_buffer.shape[0],),
-                    dtype=input_ids.dtype,
-                    device=input_ids.device,
-                ),
-                input_ids,
-            )
-            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
-
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:

@@ -1087,15 +1086,6 @@ class DeepseekV2ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-
-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
         return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch)
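The net effect of the restructuring is that embeddings, layer norms, and the logits path now run on each DP rank's local tokens, and the batch is only global between the gather before the MLP and the scatter after it. A single-process toy of that per-layer dataflow (hypothetical helpers, not DeepseekV2DecoderLayer):

    def toy_attention(tokens):
        return [t + 1 for t in tokens]     # stand-in for self-attention

    def toy_mlp(tokens):
        return [t * 10 for t in tokens]    # stand-in for the tensor-parallel MLP

    def run_layer(per_rank_tokens):
        # 1) attention runs independently on each rank's local tokens
        per_rank_tokens = [toy_attention(tok) for tok in per_rank_tokens]
        # 2) gather: concatenate local slices into one global batch
        offsets, global_batch = [], []
        for tok in per_rank_tokens:
            offsets.append(len(global_batch))
            global_batch.extend(tok)
        # 3) the MLP sees the global batch exactly once
        global_batch = toy_mlp(global_batch)
        # 4) scatter: every rank takes back only its own slice
        return [
            global_batch[off : off + len(tok)]
            for off, tok in zip(offsets, per_rank_tokens)
        ]

    print(run_layer([[1, 2, 3], [4]]))   # [[20, 30, 40], [50]]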
test/srt/test_dp_attention.py  (+1, -5)

@@ -11,7 +11,7 @@ from sglang.test.test_utils import (
 )


-class TestDPAttention(unittest.TestCase):
+class TestDPAttentionDP2TP2(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST

@@ -59,7 +59,3 @@ class TestDPAttention(unittest.TestCase):
         metrics = run_eval(args)
         print(f"{metrics=}")
         self.assertGreater(metrics["score"], 0.8)
-
-
-if __name__ == "__main__":
-    unittest.main()
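To run just this suite locally, the standard unittest machinery is enough; a sketch of a programmatic runner (it assumes you launch from the repository root so the test.srt package is importable, and that the GPUs and model weights expected by sglang's test defaults are available):

    import unittest

    if __name__ == "__main__":
        suite = unittest.defaultTestLoader.loadTestsFromName("test.srt.test_dp_attention")
        unittest.TextTestRunner(verbosity=2).run(suite)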