Unverified Commit 4582931a authored by Lianmin Zheng, committed by GitHub

Revert "Revert the changes on NCCL symmetric memory" (#10238)

parent d352c29a
@@ -510,6 +510,17 @@ class GroupCoordinator:
         if self.npu_communicator is not None and not self.npu_communicator.disabled:
             return self.npu_communicator.all_reduce(input_)
+        if (
+            self.pynccl_comm is not None
+            and hasattr(input_, "symmetric_memory")
+            and input_.symmetric_memory
+        ):
+            with self.pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream()
+            ):
+                self.pynccl_comm.all_reduce(input_)
+            return input_
         outplace_all_reduce_method = None
         if (
             self.qr_comm is not None
...
@@ -13,10 +13,14 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -1311,7 +1315,9 @@ class RowParallelLinear(LinearBase):
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+            sm.tag(output_parallel)
         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
...
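Taken together, the GroupCoordinator and RowParallelLinear hunks above define a small contract: a tensor produced inside `use_symmetric_memory(...)` is tagged via `sm.tag(...)`, and `all_reduce` checks for that tag (`input_.symmetric_memory`) to take the in-place pynccl path. Below is a minimal, standalone sketch of that contract only; the toy names (`toy_use_symmetric_memory`, `toy_all_reduce`) and the assumption that tagging simply sets a `symmetric_memory` attribute are illustrative, not SGLang's actual implementation (the real context manager in `pynccl_allocator` presumably also redirects allocations into an NCCL-registered buffer).

from contextlib import contextmanager

import torch


class _ToySymMemHandle:
    """Stand-in for the object yielded by use_symmetric_memory(...)."""

    def tag(self, t: torch.Tensor) -> None:
        # Assumption: tagging marks the tensor so all_reduce can recognize it.
        t.symmetric_memory = True


@contextmanager
def toy_use_symmetric_memory(tp_group=None):
    # Toy version: only reproduces the tagging contract, not the allocator switch.
    yield _ToySymMemHandle()


def toy_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    if getattr(input_, "symmetric_memory", False):
        # Symmetric-memory fast path: reduce in place and return the same tensor,
        # mirroring `self.pynccl_comm.all_reduce(input_); return input_`.
        return input_
    # Fallback: out-of-place reduction returning a new tensor.
    return input_.clone()


# Producer side (as in RowParallelLinear.forward): compute, then tag.
with toy_use_symmetric_memory() as sm:
    output_parallel = torch.ones(4)
    sm.tag(output_parallel)

# Consumer side (as in GroupCoordinator.all_reduce): the tag selects the in-place path.
assert toy_all_reduce(output_parallel) is output_parallel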
@@ -11,8 +11,12 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe import (
     MoeRunnerConfig,
...
@@ -11,8 +11,12 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
@@ -468,10 +472,10 @@ class VocabParallelEmbedding(torch.nn.Module):
             )
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self, masked_input.long())
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.embedding(self, masked_input.long())
+            sm.tag(output_parallel)
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
...
@@ -25,6 +25,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from torch import nn
+from tqdm import tqdm
 from transformers import PretrainedConfig
 from sglang.srt.distributed import (
@@ -34,6 +35,9 @@ from sglang.srt.distributed import (
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
@@ -524,8 +528,12 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
-        final_hidden_states += shared_output
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            final_hidden_states_out = torch.empty_like(final_hidden_states)
+            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+            final_hidden_states = final_hidden_states_out
+            sm.tag(final_hidden_states)
         if (
             self.tp_size > 1
             and not should_allreduce_fusion
@@ -563,8 +571,11 @@ class DeepseekV2MoE(nn.Module):
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                final_hidden_states_out = torch.empty_like(final_hidden_states)
+                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+                final_hidden_states = final_hidden_states_out
+                sm.tag(final_hidden_states)
         if (
             self.tp_size > 1
             and not should_allreduce_fusion
...
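One detail worth noting in the two DeepseekV2MoE hunks: the in-place `final_hidden_states += shared_output` is replaced by an out-of-place `torch.add(..., out=...)` into a buffer created with `torch.empty_like` inside the `use_symmetric_memory` block. A plausible reading (an assumption, not stated in the commit) is that the sum must live in a tensor allocated while the symmetric-memory allocator is active, so the tagged tensor handed to the subsequent all_reduce actually resides in symmetric memory; an in-place `+=` would keep writing into the previously allocated storage. A minimal sketch of the pattern:

import torch

final_hidden_states = torch.randn(8)
shared_output = torch.randn(8)

# In the real code this runs inside `with use_symmetric_memory(...) as sm:`, where the
# empty_like allocation would land in the symmetric-memory pool (assumption about intent).
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out  # rebind to the freshly allocated sum
# sm.tag(final_hidden_states) would then mark it for the NCCL symmetric all_reduce path.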