sglang · Commit 127998cc (unverified)
Authored Feb 25, 2025 by Nicolas Castet; committed Feb 25, 2025 via GitHub

Fix allgather ops inside cuda graphs (#3709)
Parent: c0bb9eb3
Showing 2 changed files with 41 additions and 7 deletions (+41, -7)
python/sglang/srt/distributed/parallel_state.py   +39  -3
python/sglang/srt/models/deepseek_v2.py            +2  -4
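What the change does: instead of calling torch.distributed.all_gather_into_tensor directly at the call sites, GroupCoordinator now exposes all_gather_into_tensor, which routes through a registered custom op (torch.ops.sglang.reg_all_gather_into_tensor) with a no-op fake implementation, following the same pattern already used for all-reduce; this keeps the collective usable when the surrounding code is captured in a CUDA graph or traced. Below is a minimal, hypothetical sketch of that register-real-plus-fake pattern using torch.library.custom_op (available in recent PyTorch); the demo namespace, the toy_all_gather name, and its world-size-1 copy body are illustrative placeholders, not sglang or commit code.

import torch
from torch.library import custom_op


# Hypothetical illustration only: "demo::toy_all_gather" is a placeholder op,
# not an sglang API. sglang registers its op through its own
# direct_register_custom_op helper rather than this decorator.
@custom_op("demo::toy_all_gather", mutates_args=("output",))
def toy_all_gather(output: torch.Tensor, input: torch.Tensor) -> None:
    # Stand-in for the real collective; with a single rank it is just a copy.
    output.copy_(input)


@toy_all_gather.register_fake
def _(output: torch.Tensor, input: torch.Tensor) -> None:
    # Fake (meta) implementation: describes the op to tracing/compilation
    # machinery without doing any communication, analogous to
    # reg_all_gather_into_tensor_fake in the diff below.
    pass


if __name__ == "__main__":
    x = torch.arange(4, dtype=torch.float32)
    y = torch.empty_like(x)
    # Dispatch through the registered op, analogous to
    # torch.ops.sglang.reg_all_gather_into_tensor(...) in GroupCoordinator.
    torch.ops.demo.toy_all_gather(y, x)
    print(y)  # tensor([0., 1., 2., 3.])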
python/sglang/srt/distributed/parallel_state.py
@@ -139,6 +139,27 @@ if supports_custom_op():
         fake_impl=outplace_all_reduce_fake,
     )
 
+    def reg_all_gather_into_tensor(
+        output: torch.Tensor, input: torch.Tensor, group_name: str
+    ) -> None:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        group._all_gather_into_tensor(output, input)
+
+    def reg_all_gather_into_tensor_fake(
+        output: torch.Tensor, input: torch.Tensor, group_name: str
+    ) -> None:
+        pass
+
+    direct_register_custom_op(
+        op_name="reg_all_gather_into_tensor",
+        op_func=reg_all_gather_into_tensor,
+        mutates_args=[],
+        fake_impl=reg_all_gather_into_tensor_fake,
+    )
+
 
 class GroupCoordinator:
     """
@@ -414,6 +435,23 @@ class GroupCoordinator:
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
 
+    def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and not pynccl_comm.disabled:
+            pynccl_comm.all_gather(output, input)
+        else:
+            torch.distributed.all_gather_into_tensor(
+                output, input, group=self.device_group
+            )
+
+    def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
+        if not supports_custom_op():
+            self._all_gather_into_tensor(output, input)
+        else:
+            torch.ops.sglang.reg_all_gather_into_tensor(
+                output, input, group_name=self.unique_name
+            )
+
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
@@ -441,9 +479,7 @@ class GroupCoordinator:
             output_size, dtype=input_.dtype, device=input_.device
         )
         # All-gather.
-        torch.distributed.all_gather_into_tensor(
-            output_tensor, input_, group=self.device_group
-        )
+        self.all_gather_into_tensor(output_tensor, input_)
        # Reshape
         output_tensor = output_tensor.reshape((world_size,) + input_size)
         output_tensor = output_tensor.movedim(0, dim)
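As a side note on the "inside cuda graphs" part of the title: whatever is enqueued inside a torch.cuda.graph() capture region is recorded once and then replayed verbatim, so the gather has to be issued as a single capturable call rather than through code paths that allocate or synchronize during capture. A minimal, hypothetical capture/replay sketch follows (assumes a CUDA-capable device; gather_like is a stand-in for the real all-gather and is not sglang code):

import torch


def gather_like(output: torch.Tensor, input: torch.Tensor) -> None:
    # Placeholder for the real all-gather; a device-side copy is enough to
    # demonstrate capture and replay.
    output.copy_(input)


if __name__ == "__main__":
    inp = torch.zeros(4, device="cuda")
    out = torch.empty(4, device="cuda")

    # Warm up on a side stream before capture, as recommended by PyTorch.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        gather_like(out, inp)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        gather_like(out, inp)  # recorded once, with these exact buffers

    inp.fill_(3.0)
    g.replay()                 # re-runs the captured work on the same buffers
    print(out)                 # tensor([3., 3., 3., 3.], device='cuda:0')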
python/sglang/srt/models/deepseek_v2.py
@@ -824,9 +824,7 @@ def all_gather(
         input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
     )
 
-    torch.distributed.all_gather_into_tensor(
-        forward_batch.gathered_buffer, padded_tensor, group=group
-    )
+    group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
 
     gathered_tensors = torch.concat(
         [
@@ -862,7 +860,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.enable_dp_attention:
             self.tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group().device_group
+            self.tp_group = get_tp_group()
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,