sglang commit 4582931a (unverified)

Revert "Revert the changes on NCCL symmetric memory" (#10238)

Authored by Lianmin Zheng on Sep 09, 2025; committed via GitHub on Sep 09, 2025.
Parent commit: d352c29a

Showing 5 changed files with 43 additions and 7 deletions (+43, -7):
  python/sglang/srt/distributed/parallel_state.py          +11  -0
  python/sglang/srt/layers/linear.py                         +7  -1
  python/sglang/srt/layers/moe/fused_moe_triton/layer.py     +4  -0
  python/sglang/srt/layers/vocab_parallel_embedding.py       +7  -3
  python/sglang/srt/models/deepseek_v2.py                   +14  -3
python/sglang/srt/distributed/parallel_state.py

@@ -510,6 +510,17 @@ class GroupCoordinator:
         if self.npu_communicator is not None and not self.npu_communicator.disabled:
             return self.npu_communicator.all_reduce(input_)
 
+        if (
+            self.pynccl_comm is not None
+            and hasattr(input_, "symmetric_memory")
+            and input_.symmetric_memory
+        ):
+            with self.pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream()
+            ):
+                self.pynccl_comm.all_reduce(input_)
+                return input_
+
         outplace_all_reduce_method = None
         if (
             self.qr_comm is not None
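Note on the branch above: the fast path is taken only when a pynccl communicator exists and the input tensor carries a truthy `symmetric_memory` attribute, which the `sm.tag(...)` calls added in the other files appear to set. Below is a minimal, CPU-only sketch of just that gating logic. `_RecordingComm` and `all_reduce_fast_path` are hypothetical stand-ins for illustration, not sglang APIs; the real code dispatches to `self.pynccl_comm` in place under `torch.cuda.current_stream()`.

import torch


class _RecordingComm:
    """Hypothetical stand-in for self.pynccl_comm: records the tensors it was
    asked to all-reduce instead of launching NCCL kernels."""

    def __init__(self):
        self.reduced = []

    def all_reduce(self, tensor: torch.Tensor) -> None:
        # The real pynccl communicator reduces `tensor` in place across TP ranks.
        self.reduced.append(tensor)


def all_reduce_fast_path(comm, input_: torch.Tensor):
    """Simplified mirror of the branch added to GroupCoordinator.all_reduce."""
    if (
        comm is not None
        and hasattr(input_, "symmetric_memory")
        and input_.symmetric_memory
    ):
        # In sglang this happens under pynccl_comm.change_state(enable=True, ...).
        comm.all_reduce(input_)
        return input_
    # Untagged tensors fall through to the existing out-of-place all-reduce paths.
    return None


comm = _RecordingComm()

tagged = torch.ones(4)
tagged.symmetric_memory = True  # attribute that sm.tag() is assumed to set
assert all_reduce_fast_path(comm, tagged) is tagged

untagged = torch.ones(4)
assert all_reduce_fast_path(comm, untagged) is None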
python/sglang/srt/layers/linear.py

@@ -13,10 +13,14 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,

@@ -1311,7 +1315,9 @@ class RowParallelLinear(LinearBase):
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+            sm.tag(output_parallel)
         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
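The RowParallelLinear change follows a producer/consumer split: the local GEMM runs inside `use_symmetric_memory(...)`, its output is tagged, and the subsequent `tensor_model_parallel_all_reduce` can then take the symmetric-memory path shown in `GroupCoordinator.all_reduce`. The sketch below mimics only the shape of that pattern; `fake_use_symmetric_memory` and `row_parallel_forward` are hypothetical, and the real allocator additionally serves tensor allocations made inside the scope from NCCL-registered memory.

from contextlib import contextmanager

import torch


@contextmanager
def fake_use_symmetric_memory(group=None):
    """Hypothetical stand-in for use_symmetric_memory(): only reproduces the
    tagging half of the contract, not the allocator swap."""

    class _Scope:
        def tag(self, tensor: torch.Tensor) -> None:
            tensor.symmetric_memory = True  # checked later by the all-reduce fast path

    yield _Scope()


def row_parallel_forward(weight: torch.Tensor, input_parallel: torch.Tensor) -> torch.Tensor:
    # Same shape as the RowParallelLinear change: run the local GEMM inside the
    # scope, tag its output, then (not shown) hand it to tensor_model_parallel_all_reduce.
    with fake_use_symmetric_memory() as sm:
        output_parallel = input_parallel @ weight.t()
        sm.tag(output_parallel)
    return output_parallel


out = row_parallel_forward(torch.randn(8, 16), torch.randn(4, 16))
assert getattr(out, "symmetric_memory", False)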
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -11,8 +11,12 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe import (
     MoeRunnerConfig,
python/sglang/srt/layers/vocab_parallel_embedding.py

@@ -11,8 +11,12 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter

@@ -468,10 +472,10 @@ class VocabParallelEmbedding(torch.nn.Module):
             )
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self, masked_input.long())
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.embedding(self, masked_input.long())
+            sm.tag(output_parallel)
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
python/sglang/srt/models/deepseek_v2.py

@@ -25,6 +25,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from torch import nn
+from tqdm import tqdm
 from transformers import PretrainedConfig
 from sglang.srt.distributed import (

@@ -34,6 +35,9 @@ from sglang.srt.distributed import (
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo

@@ -524,8 +528,12 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states *= self.routed_scaling_factor
             current_stream.wait_stream(self.alt_stream)
-            final_hidden_states += shared_output
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                final_hidden_states_out = torch.empty_like(final_hidden_states)
+                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+                final_hidden_states = final_hidden_states_out
+                sm.tag(final_hidden_states)
         if (
             self.tp_size > 1
             and not should_allreduce_fusion

@@ -563,8 +571,11 @@ class DeepseekV2MoE(nn.Module):
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                final_hidden_states_out = torch.empty_like(final_hidden_states)
+                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+                final_hidden_states = final_hidden_states_out
+                sm.tag(final_hidden_states)
         if (
             self.tp_size > 1
             and not should_allreduce_fusion
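In both DeepseekV2MoE hunks, the in-place `final_hidden_states += shared_output` (and the equivalent `x = x + y`) is replaced by an out-of-place `torch.add(..., out=...)` into a buffer created with `torch.empty_like` inside the `use_symmetric_memory` scope. A plausible reading, inferred from the diff rather than stated in the commit, is that only tensors allocated inside the scope come from the symmetric-memory pool, so the result handed to the all-reduce must be a fresh allocation rather than a write into the pre-existing buffer. A minimal CPU-only sketch of the arithmetic pattern:

import torch

final_hidden_states = torch.randn(4, 8)
shared_output = torch.randn(4, 8)

# Old pattern (removed): in-place add reuses the existing, untagged buffer.
# final_hidden_states += shared_output

# New pattern (added): allocate the result inside the use_symmetric_memory scope
# and add into it, so the tensor later passed to all-reduce is a fresh allocation.
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
# ...followed by sm.tag(final_hidden_states) in the real code.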