sglang commit 3196999f (unverified)

Reduce computation and communication in DP attention (#4521)

Authored Mar 18, 2025 by Cheng Wan; committed via GitHub on Mar 18, 2025
Parent: 9e0186f3

Showing 5 changed files with 65 additions and 75 deletions (+65 −75)

Changed files:

  python/sglang/srt/distributed/parallel_state.py    +6   −2
  python/sglang/srt/layers/dp_attention.py           +21  −21
  python/sglang/srt/layers/logits_processor.py       +2   −2
  python/sglang/srt/models/deepseek_v2.py            +35  −45
  test/srt/test_dp_attention.py                      +1   −5

python/sglang/srt/distributed/parallel_state.py

@@ -189,6 +189,9 @@ class GroupCoordinator:
     device_group: ProcessGroup  # group for device communication
     use_pynccl: bool  # a hint of whether to use PyNccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
+    use_message_queue_broadcaster: (
+        bool  # a hint of whether to use message queue broadcaster
+    )
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator

@@ -241,6 +244,7 @@ class GroupCoordinator:
         self.use_custom_allreduce = use_custom_allreduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
+        self.use_message_queue_broadcaster = use_message_queue_broadcaster
         # lazy import to avoid documentation build error
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (

@@ -269,7 +273,7 @@ class GroupCoordinator:
             HpuCommunicator,
         )
-        self.hpu_communicator: Optional[HpuCommunicator]
+        self.hpu_communicator: Optional[HpuCommunicator] = None
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)

@@ -277,7 +281,7 @@ class GroupCoordinator:
             XpuCommunicator,
         )
-        self.xpu_communicator: Optional[XpuCommunicator]
+        self.xpu_communicator: Optional[XpuCommunicator] = None
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)
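
A note on this file: besides storing the new use_message_queue_broadcaster hint, the constructor now initializes the optional HPU/XPU communicators to None before the world-size check, so the attributes always exist even on a single-process group. A minimal, self-contained sketch of that default-to-None pattern follows; ToyCoordinator and describe are illustrative stand-ins, not the real GroupCoordinator API.

# Minimal sketch (illustrative only, not sglang's GroupCoordinator): default the
# optional communicator attribute to None so it is defined even when world_size == 1.
from typing import Any, Optional


class ToyCoordinator:
    def __init__(self, world_size: int, use_hpu_communicator: bool = False):
        self.world_size = world_size
        # Default first; only replace it when a real device communicator is wanted.
        self.hpu_communicator: Optional[Any] = None
        if use_hpu_communicator and world_size > 1:
            self.hpu_communicator = object()  # stand-in for HpuCommunicator(group=...)

    def describe(self) -> str:
        # Safe to check for None even on a single-process group.
        return "device communicator" if self.hpu_communicator is not None else "no-op"


print(ToyCoordinator(world_size=1).describe())  # -> no-op
print(ToyCoordinator(world_size=2, use_hpu_communicator=True).describe())  # -> device communicator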

python/sglang/srt/layers/dp_attention.py

@@ -53,10 +53,8 @@ def initialize_dp_attention(
     )

     if enable_dp_attention:
-        local_rank = tp_rank % (tp_size // dp_size)
         _DP_SIZE = dp_size
     else:
-        local_rank = tp_rank
         _DP_SIZE = 1

     tp_group = get_tp_group()

@@ -65,7 +63,7 @@ def initialize_dp_attention(
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-        local_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,

@@ -180,20 +178,19 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)


-def dp_gather(
+def _dp_gather(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
-    layer_id: Union[str, int],
+    is_partial: bool,
 ):
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

     global_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
     if local_tokens.shape[0] > 0 and (
-        layer_id != "embedding" or get_attention_tp_rank() == 0
+        is_partial or get_attention_tp_rank() == 0
     ):
         assert (
             global_tokens.untyped_storage().data_ptr()
             != local_tokens.untyped_storage().data_ptr()

@@ -216,6 +213,22 @@ def dp_gather(
         global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)


+def dp_gather_partial(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True)
+
+
+def dp_gather_replicate(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=False)
+
+
 def dp_scatter(
     local_tokens: torch.Tensor,  # output
     global_tokens: torch.Tensor,  # input

@@ -236,16 +249,3 @@ def dp_scatter(
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )
-
-
-def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
-    def do_logits_dp_scatter(logits: torch.Tensor):
-        local_logits = torch.empty(
-            (forward_batch.input_ids.shape[0], *logits.shape[1:]),
-            dtype=logits.dtype,
-            device=logits.device,
-        )
-        dp_scatter(local_logits, logits, forward_batch)
-        return local_logits
-
-    return do_logits_dp_scatter
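
A note on this file: dp_gather becomes a private _dp_gather whose is_partial flag controls who writes into the gather buffer. Partial tensors (per-rank shards that the following tensor_model_parallel_all_reduce sums up) are contributed by every attention-TP rank; replicated tensors are contributed by attention-TP rank 0 only, otherwise the subsequent all-reduce would multiply them. The wrappers dp_gather_partial and dp_gather_replicate just fix that flag. Below is a minimal single-process sketch of the gate; the all-reduce is simulated with a plain Python sum, and toy_dp_gather is a hypothetical stand-in, not the sglang function.

# Single-process sketch of the "is_partial or attn_tp_rank == 0" gate; the real code
# all-reduces global_tokens across the attention-TP group instead of summing in Python.
import torch

ATTN_TP_SIZE = 2  # hypothetical attention tensor-parallel group size


def toy_dp_gather(local_tokens: torch.Tensor, attn_tp_rank: int, is_partial: bool) -> torch.Tensor:
    global_tokens = torch.zeros_like(local_tokens)
    # Partial shards must come from every rank; replicated tensors only from rank 0,
    # or the simulated all-reduce below would count them ATTN_TP_SIZE times.
    if local_tokens.shape[0] > 0 and (is_partial or attn_tp_rank == 0):
        global_tokens.copy_(local_tokens)
    return global_tokens


x = torch.ones(4)
# Replicated input (e.g. hidden states entering the lm_head): only rank 0 contributes.
replicated = sum(toy_dp_gather(x, r, is_partial=False) for r in range(ATTN_TP_SIZE))
# Partial input (e.g. a row-parallel matmul output): every rank contributes its shard.
partial = sum(toy_dp_gather(x / ATTN_TP_SIZE, r, is_partial=True) for r in range(ATTN_TP_SIZE))
print(replicated, partial)  # both recover tensor([1., 1., 1., 1.])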

python/sglang/srt/layers/logits_processor.py

@@ -28,7 +28,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
     get_attention_dp_size,

@@ -428,7 +428,7 @@ class LogitsProcessor(nn.Module):
                 logits_metadata.gathered_buffer,
                 hidden_states.clone(),
             )
-            dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
+            dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

         if hasattr(lm_head, "weight"):
             logits = torch.matmul(
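
A note on this file: the hidden states fed to the LM head are already replicated across the attention-TP group, so the call site switches from dp_gather(..., "embedding") to the explicit dp_gather_replicate. The hidden_states.clone() in the context above appears to exist because _dp_gather asserts that source and destination do not share storage. A small sketch of that swap-into-preallocated-buffer idiom, with hypothetical shapes rather than the real LogitsProcessor code:

# Sketch of the buffer-swap idiom around dp_gather_replicate (hypothetical sizes).
import torch

gathered_buffer = torch.zeros(8, 16)   # preallocated capacity for all DP ranks' tokens
local_hidden = torch.randn(3, 16)      # this rank's hidden states

# Swap: the big buffer becomes the new global hidden_states; keep a distinct copy of
# the local tensor so the two never alias (the clone() mirrors the real call site).
hidden_states, local_hidden_states = gathered_buffer, local_hidden.clone()

# dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) would fill
# this rank's slice; emulate just the local copy here.
hidden_states[: local_hidden_states.shape[0]] = local_hidden_states
print(hidden_states.shape)  # torch.Size([8, 16]) -- ready for the lm_head matmul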

python/sglang/srt/models/deepseek_v2.py

@@ -33,7 +33,7 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,

@@ -939,47 +939,58 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if residual is None:
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        # Scatter
-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
-        # Self Attention
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual
+                )
+
+            # Self Attention
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )

         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
             if self.dp_size != 1:
+                if get_attention_tp_rank() == 0:
+                    hidden_states += residual
                 hidden_states, local_hidden_states = (
                     forward_batch.gathered_buffer,
                     hidden_states,
                 )
-                dp_gather(
-                    hidden_states, local_hidden_states, forward_batch, self.layer_id
-                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                hidden_states = self.post_attention_layernorm(hidden_states)
             else:
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )

         # Fully Connected
         hidden_states = self.mlp(hidden_states)

+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return hidden_states, residual

@@ -1025,18 +1036,6 @@ class DeepseekV2Model(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        # Gather
-        if self.dp_size != 1:
-            input_ids, local_input_ids = (
-                torch.empty(
-                    (forward_batch.gathered_buffer.shape[0],),
-                    dtype=input_ids.dtype,
-                    device=input_ids.device,
-                ),
-                input_ids,
-            )
-            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
-
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:

@@ -1087,15 +1086,6 @@ class DeepseekV2ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-
-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)

         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
         )
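
A note on this file: the decoder layer no longer scatters and gathers around every attention call. Attention now runs only on the rank's local tokens; on attention-TP rank 0 the residual is added before a single dp_gather_partial, the fresh residual is re-derived with dp_scatter, and the MLP output is scattered once at the end of the layer, while the DP gather of input_ids in DeepseekV2Model and the extra scatter in DeepseekV2ForCausalLM are removed. The new hidden_states.shape[0] == 0 guard covers a DP rank that holds no local tokens in a step but must still join the collectives. A minimal sketch of just that guard, with a layer-norm stand-in and hypothetical shapes (not the sglang layer):

# Sketch: under DP attention a rank may receive zero local tokens, so the local
# layernorm/attention work is skipped, but a well-formed empty tensor is still
# returned for the later gather/scatter collectives.
import torch


def forward_local_part(hidden_states: torch.Tensor) -> torch.Tensor:
    if hidden_states.shape[0] == 0:
        # No local tokens on this DP rank: nothing to normalize or attend over.
        return hidden_states
    # Stand-in for input_layernorm + self_attn over this rank's tokens only.
    return torch.nn.functional.layer_norm(hidden_states, hidden_states.shape[-1:])


print(forward_local_part(torch.empty(0, 16)).shape)  # torch.Size([0, 16])
print(forward_local_part(torch.randn(3, 16)).shape)  # torch.Size([3, 16])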

test/srt/test_dp_attention.py

@@ -11,7 +11,7 @@ from sglang.test.test_utils import (
 )


-class TestDPAttention(unittest.TestCase):
+class TestDPAttentionDP2TP2(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST

@@ -59,7 +59,3 @@ class TestDPAttention(unittest.TestCase):
         metrics = run_eval(args)
         print(f"{metrics=}")
         self.assertGreater(metrics["score"], 0.8)
-
-
-if __name__ == "__main__":
-    unittest.main()
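
A note on this file: the test class is renamed to TestDPAttentionDP2TP2 and an if __name__ == "__main__" block is dropped in this hunk. One way to run only the renamed class is the standard unittest loader; the snippet below is a hypothetical invocation (it assumes the repository root is on PYTHONPATH and the sglang runtime is installed), not the project's official CI entry point.

# Run only the renamed DP attention test class by its dotted name.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "test.srt.test_dp_attention.TestDPAttentionDP2TP2"
)
unittest.TextTestRunner(verbosity=2).run(suite)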