change/sglang · Commits · 3196999f
"vscode:/vscode.git/clone" did not exist on "aba7fefce7f7b866e62403c4c4bb1354af32031c"
Commit 3196999f (unverified) · authored Mar 18, 2025 by Cheng Wan · committed by GitHub, Mar 18, 2025

Reduce computation and communication in DP attention (#4521)

Parent: 9e0186f3
Showing 5 changed files with 65 additions and 75 deletions.
python/sglang/srt/distributed/parallel_state.py   +6 -2
python/sglang/srt/layers/dp_attention.py          +21 -21
python/sglang/srt/layers/logits_processor.py      +2 -2
python/sglang/srt/models/deepseek_v2.py           +35 -45
test/srt/test_dp_attention.py                     +1 -5
python/sglang/srt/distributed/parallel_state.py

```diff
@@ -189,6 +189,9 @@ class GroupCoordinator:
     device_group: ProcessGroup  # group for device communication
     use_pynccl: bool  # a hint of whether to use PyNccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
+    use_message_queue_broadcaster: (
+        bool  # a hint of whether to use message queue broadcaster
+    )
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
@@ -241,6 +244,7 @@ class GroupCoordinator:
         self.use_custom_allreduce = use_custom_allreduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
+        self.use_message_queue_broadcaster = use_message_queue_broadcaster

         # lazy import to avoid documentation build error
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
@@ -269,7 +273,7 @@ class GroupCoordinator:
             HpuCommunicator,
         )
-        self.hpu_communicator: Optional[HpuCommunicator]
+        self.hpu_communicator: Optional[HpuCommunicator] = None
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)
@@ -277,7 +281,7 @@ class GroupCoordinator:
             XpuCommunicator,
         )
-        self.xpu_communicator: Optional[XpuCommunicator]
+        self.xpu_communicator: Optional[XpuCommunicator] = None
        if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)
```
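The two one-line `= None` fixes above matter because a bare annotation such as `self.hpu_communicator: Optional[HpuCommunicator]` only records a type hint and never creates the attribute; if the guarded branch is skipped (single process, or the communicator disabled), any later read raises `AttributeError`. A minimal sketch of the difference, using a stand-in class rather than the real `GroupCoordinator`:

```python
from typing import Optional


class Coordinator:
    def __init__(self, world_size: int, use_hpu: bool) -> None:
        # A bare annotation like `self.hpu_communicator: Optional[object]` would
        # only declare a type; the attribute itself would not exist unless the
        # branch below ran. Assigning None up front (the pattern the commit
        # switches to) guarantees the attribute is always present.
        self.hpu_communicator: Optional[object] = None
        if use_hpu and world_size > 1:
            self.hpu_communicator = object()  # stand-in for HpuCommunicator(group=...)


c = Coordinator(world_size=1, use_hpu=True)
assert c.hpu_communicator is None  # safe even though no communicator was created
```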
python/sglang/srt/layers/dp_attention.py

```diff
@@ -53,10 +53,8 @@ def initialize_dp_attention(
     )

     if enable_dp_attention:
-        local_rank = tp_rank % (tp_size // dp_size)
         _DP_SIZE = dp_size
     else:
-        local_rank = tp_rank
         _DP_SIZE = 1

     tp_group = get_tp_group()
@@ -65,7 +63,7 @@ def initialize_dp_attention(
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-        local_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
@@ -180,20 +178,19 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](
         dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE
     )


-def dp_gather(
+def _dp_gather(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
-    layer_id: Union[str, int],
+    is_partial: bool,
 ):
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

     global_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
-    if local_tokens.shape[0] > 0 and (
-        layer_id != "embedding" or get_attention_tp_rank() == 0
-    ):
+    if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
         assert (
             global_tokens.untyped_storage().data_ptr()
             != local_tokens.untyped_storage().data_ptr()
@@ -216,6 +213,22 @@ def dp_gather(
         global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)


+def dp_gather_partial(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True)
+
+
+def dp_gather_replicate(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=False)
+
+
 def dp_scatter(
     local_tokens: torch.Tensor,  # output
     global_tokens: torch.Tensor,  # input
@@ -236,16 +249,3 @@ def dp_scatter(
         memcpy_triton(
             local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
         )
-
-
-def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
-    def do_logits_dp_scatter(logits: torch.Tensor):
-        local_logits = torch.empty(
-            (forward_batch.input_ids.shape[0], *logits.shape[1:]),
-            dtype=logits.dtype,
-            device=logits.device,
-        )
-        dp_scatter(local_logits, logits, forward_batch)
-        return local_logits
-
-    return do_logits_dp_scatter
```
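The refactor replaces the `layer_id`-driven branch (`layer_id != "embedding"`) with an explicit flag: `dp_gather_partial` is for tensors where each attention-TP rank holds a partial sum that the trailing all-reduce should add up, while `dp_gather_replicate` is for tensors that are identical on every attention-TP rank, so only rank 0 may contribute or the all-reduce would multiply the values. A toy simulation of that condition (plain Python, not sglang code):

```python
def gather_then_all_reduce(per_rank_values, is_partial):
    # Each rank writes its local value into the gathered buffer only if it
    # holds partial data, or if it is attention-TP rank 0 (replicated case),
    # mirroring the condition inside _dp_gather().
    contributions = [
        v if (is_partial or rank == 0) else 0.0
        for rank, v in enumerate(per_rank_values)
    ]
    return sum(contributions)  # stands in for the all-reduce


# Partial case: ranks hold per-shard partial outputs that must be summed.
assert gather_then_all_reduce([1.0, 2.0, 3.0], is_partial=True) == 6.0

# Replicated case: every rank holds the same full tensor; summing all copies
# would scale it by the group size, so only rank 0 contributes.
assert gather_then_all_reduce([5.0, 5.0, 5.0], is_partial=False) == 5.0
```

This is why `logits_processor.py` below switches to `dp_gather_replicate` (hidden states there are already complete on every rank) while `deepseek_v2.py` switches to `dp_gather_partial` for the attention output.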
python/sglang/srt/layers/logits_processor.py

```diff
@@ -28,7 +28,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
     get_attention_dp_size,
@@ -428,7 +428,7 @@ class LogitsProcessor(nn.Module):
                 logits_metadata.gathered_buffer,
                 hidden_states.clone(),
             )
-            dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
+            dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

         if hasattr(lm_head, "weight"):
             logits = torch.matmul(
```
python/sglang/srt/models/deepseek_v2.py

```diff
@@ -33,7 +33,7 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
@@ -939,47 +939,58 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         if residual is None:
-            if hidden_states.shape[0] == 0:
-                residual = hidden_states
+            residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)

-        # Scatter
-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
         )

         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
             if self.dp_size != 1:
                 if get_attention_tp_rank() == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
                     forward_batch.gathered_buffer,
                     hidden_states,
                 )
-                dp_gather(hidden_states, local_hidden_states, forward_batch, self.layer_id)
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
                 dp_scatter(residual, hidden_states, forward_batch)
                 hidden_states = self.post_attention_layernorm(hidden_states)
             else:
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                 hidden_states, residual = self.post_attention_layernorm(
                     hidden_states, residual
                 )
         else:
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual
             )

         # Fully Connected
         hidden_states = self.mlp(hidden_states)

+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return hidden_states, residual
@@ -1025,18 +1036,6 @@ class DeepseekV2Model(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        # Gather
-        if self.dp_size != 1:
-            input_ids, local_input_ids = (
-                torch.empty(
-                    (forward_batch.gathered_buffer.shape[0],),
-                    dtype=input_ids.dtype,
-                    device=input_ids.device,
-                ),
-                input_ids,
-            )
-            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
-
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
@@ -1087,15 +1086,6 @@ class DeepseekV2ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
         return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch)
```
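The `# important: forward batch.gathered_buffer is used both after scatter and after gather` warning refers to the tuple-swap pattern in the hunks above: a slice of the shared gathered buffer becomes the new `hidden_states` while the previous tensor is kept as the copy source. A small self-contained illustration of that pattern with made-up shapes (stand-ins, not the actual `ForwardBatch`):

```python
import torch

# Stand-ins: one shared buffer sized for the gathered (global) batch, and a
# "global" activation tensor that should be scattered down to the local tokens.
gathered_buffer = torch.zeros(8, 4)   # global token capacity x hidden size
global_hidden = torch.randn(8, 4)
num_local_tokens = 3

# The tuple swap used in the diff: hidden_states becomes a view of the shared
# buffer sliced to the local length, and the previous tensor is kept around as
# the source that dp_scatter() would copy from.
hidden_states, global_hidden_states = (
    gathered_buffer[:num_local_tokens],
    global_hidden,
)
hidden_states.copy_(global_hidden_states[:num_local_tokens])  # stand-in for dp_scatter

# hidden_states now aliases gathered_buffer, so any later gather that writes
# into gathered_buffer also touches it -- presumably why the diff warns about it.
assert hidden_states.data_ptr() == gathered_buffer.data_ptr()
```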
test/srt/test_dp_attention.py

```diff
@@ -11,7 +11,7 @@ from sglang.test.test_utils import (
 )


-class TestDPAttention(unittest.TestCase):
+class TestDPAttentionDP2TP2(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
@@ -59,7 +59,3 @@ class TestDPAttention(unittest.TestCase):
         metrics = run_eval(args)
         print(f"{metrics=}")
         self.assertGreater(metrics["score"], 0.8)
-
-
-if __name__ == "__main__":
-    unittest.main()
```
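With the `if __name__ == "__main__"` guard removed in this hunk, one way to invoke the renamed class directly is through the unittest API. This is only a sketch: it assumes the sglang repository root is on `sys.path` (so `test/srt/test_dp_attention.py` is importable as a module) and a multi-GPU environment that can actually launch DP attention.

```python
import unittest

# Hypothetical runner for the renamed test class; the module path follows the
# repository layout shown in this commit (test/srt/test_dp_attention.py).
from test.srt.test_dp_attention import TestDPAttentionDP2TP2

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestDPAttentionDP2TP2)
unittest.TextTestRunner(verbosity=2).run(suite)
```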