Unverified commit 23407983, authored by Nicolas Castet, committed by GitHub

Register allgather/reducescatter buffers with symm memory (#12572)

parent 1357ab02
import os
import tempfile
from contextlib import nullcontext
import torch
from torch.cuda.memory import CUDAPluggableAllocator
from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.server_args import get_global_server_args
nccl_allocator_source = """
......@@ -60,6 +60,9 @@ _cur_device = None
def is_symmetric_memory_enabled():
# Import here to avoid circular import
from sglang.srt.server_args import get_global_server_args
return get_global_server_args().enable_symm_mem
......@@ -92,7 +95,7 @@ def get_nccl_mem_pool():
return _mem_pool
class use_symmetric_memory:
class SymmetricMemoryContext:
"""
Context manager for using symmetric memory with pynccl.
......@@ -100,25 +103,17 @@ class use_symmetric_memory:
by `ncclMemAlloc` and registered by `ncclCommWindowRegister`. Due to this, we introduce
this context manager. All tensors created under this context will be correctly
allocated and registered with a custom allocator.
In addition, developers need to manually tag the tensors that will be used as the input/output
of NCCL collectives with `tag(tensor)`.
"""
def __init__(self, group_coordinator: GroupCoordinator):
self.enabled = is_symmetric_memory_enabled()
if not self.enabled:
return
def __init__(
self,
group_coordinator: GroupCoordinator,
):
self.group_coordinator = group_coordinator
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
def __enter__(self):
if not self.enabled:
return self
assert (
self.group_coordinator.pynccl_comm is not None
), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
......@@ -139,16 +134,16 @@ class use_symmetric_memory:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if not self.enabled:
return
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
if self.is_graph_capture:
torch._C._cuda_beginAllocateCurrentThreadToPool(_cur_device, _graph_pool_id)
def tag(self, tensor: torch.Tensor):
if not self.enabled:
return
tensor.symmetric_memory = True
def use_symmetric_memory(group_coordinator: GroupCoordinator, disabled: bool = False):
disabled = (
not is_symmetric_memory_enabled()
or disabled
or group_coordinator.world_size == 1
)
return SymmetricMemoryContext(group_coordinator) if not disabled else nullcontext()
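For reference, a minimal usage sketch of the API above (illustrative only: it assumes torch.distributed and the sglang TP group are already initialized, pynccl is enabled, and the shapes are made up):

import torch

from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)

tp_group = get_tp_group()
# Tensors created inside the context are allocated from the NCCL mem pool
# (ncclMemAlloc) and registered via ncclCommWindowRegister; when symmetric
# memory is disabled or world_size == 1, the factory returns a nullcontext
# and this degrades to a plain allocation.
with use_symmetric_memory(tp_group):
    buf = torch.empty((1024, 4096), dtype=torch.bfloat16, device="cuda")
buf = tp_group.all_reduce(buf)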
......@@ -188,6 +188,27 @@ if _supports_custom_op:
fake_impl=reg_all_gather_into_tensor_fake,
)
def reg_reduce_scatter_tensor(
output: torch.Tensor, input: torch.Tensor, group_name: str
) -> None:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
group._reduce_scatter_tensor(output, input)
def reg_reduce_scatter_tensor_fake(
output: torch.Tensor, input: torch.Tensor, group_name: str
) -> None:
pass
direct_register_custom_op(
op_name="reg_reduce_scatter_tensor",
op_func=reg_reduce_scatter_tensor,
mutates_args=["output"],
fake_impl=reg_reduce_scatter_tensor_fake,
)
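The fake implementation registered above is what the dispatcher runs while tracing (e.g. under torch.compile), where no real communication can execute. sglang's direct_register_custom_op helper hides the plumbing; conceptually the pattern is close to the stock torch.library API sketched below (illustration only, with a hypothetical op name, assuming PyTorch >= 2.4; this is not sglang's helper):

import torch

@torch.library.custom_op("demo::reduce_scatter_into_tensor", mutates_args=("output",))
def demo_reduce_scatter_into_tensor(
    output: torch.Tensor, input: torch.Tensor, group_name: str
) -> None:
    # A real implementation would look up the group by name and call its
    # reduce-scatter, as reg_reduce_scatter_tensor does above.
    output.copy_(input[: output.shape[0]])

@demo_reduce_scatter_into_tensor.register_fake
def _(output: torch.Tensor, input: torch.Tensor, group_name: str) -> None:
    # No computation: only tensor metadata matters during tracing.
    pass

# Callers then go through the dispatcher, mirroring the diff:
# torch.ops.demo.reduce_scatter_into_tensor(out, inp, group_name="tp")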
class GroupCoordinator:
"""
......@@ -314,10 +335,16 @@ class GroupCoordinator:
from sglang.srt.distributed.device_communicators.pynccl import (
PyNcclCommunicator,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled,
use_symmetric_memory,
)
from sglang.srt.distributed.device_communicators.torch_symm_mem import (
TorchSymmMemCommunicator,
)
self.is_symmetric_memory_enabled = is_symmetric_memory_enabled
self.use_symmetric_memory = use_symmetric_memory
if is_hip():
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce,
......@@ -552,7 +579,7 @@ class GroupCoordinator:
if self.npu_communicator is not None and not self.npu_communicator.disabled:
return self.npu_communicator.all_reduce(input_)
if self.pynccl_comm is not None and getattr(input_, "symmetric_memory", False):
if self.pynccl_comm is not None and self.is_symmetric_memory_enabled():
with self.pynccl_comm.change_state(
enable=True, stream=get_current_device_stream_fast()
):
......@@ -627,15 +654,33 @@ class GroupCoordinator:
else:
torch.distributed.all_reduce(input_, group=self.device_group)
def reduce_scatter_tensor(
def _reduce_scatter_tensor(
self,
output: torch.Tensor,
input: torch.Tensor,
) -> None:
# TODO(ch-wan): support other backends
torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
) -> torch.Tensor:
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and (
not pynccl_comm.disabled or self.is_symmetric_memory_enabled()
):
with pynccl_comm.change_state(
enable=True, stream=get_current_device_stream_fast()
):
pynccl_comm.reduce_scatter(output, input)
else:
torch.distributed.reduce_scatter_tensor(
output, input, group=self.device_group
)
return output
def reduce_scatter_tensor(self, output: torch.Tensor, input: torch.Tensor):
if _is_npu or not supports_custom_op():
self._reduce_scatter_tensor(output, input)
else:
torch.ops.sglang.reg_reduce_scatter_tensor(
output, input, group_name=self.unique_name
)
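Putting the two pieces together, a hedged sketch of how a caller reaches the new reduce-scatter path (assumes an initialized TP group; sizes and dtype are illustrative):

import torch

from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)

tp = get_tp_group()
full = torch.randn(tp.world_size * 128, 4096, device="cuda", dtype=torch.bfloat16)
# Allocate the output shard under the symmetric-memory context so pynccl can
# use the registered window buffers when symmetric memory is enabled.
with use_symmetric_memory(tp):
    shard = torch.empty((128, 4096), device="cuda", dtype=torch.bfloat16)
# Dispatches through torch.ops.sglang.reg_reduce_scatter_tensor when custom ops
# are supported; otherwise it falls back to _reduce_scatter_tensor directly.
tp.reduce_scatter_tensor(shard, full)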
def reduce_scatter(
self,
output: torch.Tensor,
......@@ -682,8 +727,13 @@ class GroupCoordinator:
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.all_gather(output, input)
if pynccl_comm is not None and (
not pynccl_comm.disabled or self.is_symmetric_memory_enabled()
):
with pynccl_comm.change_state(
enable=True, stream=get_current_device_stream_fast()
):
pynccl_comm.all_gather(output, input)
else:
torch.distributed.all_gather_into_tensor(
output, input, group=self.device_group
......@@ -745,9 +795,10 @@ class GroupCoordinator:
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
with self.use_symmetric_memory(self):
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
# All-gather.
if input_.is_cpu:
......@@ -787,7 +838,7 @@ class GroupCoordinator:
pynccl_comm is not None and not pynccl_comm.disabled
), "pynccl is required for all_gatherv"
def _all_gather_single(
def _all_gather_allocate_output(
input_: torch.Tensor, sizes: Optional[List[int]] = None
):
input_size = input_.size()
......@@ -801,19 +852,25 @@ class GroupCoordinator:
else:
output_size = (input_size[0] * world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
return output_tensor
with self.use_symmetric_memory(self, disabled=sizes is not None):
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
return output_tensor, sizes
if isinstance(input_, torch.Tensor):
return _all_gather_single(input_, sizes)
input_ = [input_]
output_list = []
pynccl_comm.group_start()
size_list = []
for inp in input_:
output_list.append(_all_gather_single(inp, sizes=sizes))
output_tensor, s = _all_gather_allocate_output(inp, sizes=sizes)
output_list.append(output_tensor)
size_list.append(s)
pynccl_comm.group_start()
for i, inp in enumerate(input_):
pynccl_comm.all_gather(output_list[i], inp, sizes=size_list[i])
pynccl_comm.group_end()
return output_list
......
......@@ -21,8 +21,12 @@ import torch
from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather_into_tensor,
attn_tp_reduce_scatter_tensor,
......@@ -34,6 +38,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_size,
get_global_dp_buffer,
get_local_dp_buffer,
is_allocation_symmetric,
is_dp_attention_enabled,
)
from sglang.srt.layers.moe import (
......@@ -540,7 +545,12 @@ class CommunicateWithAllReduceAndLayerNormFn:
use_layer_norm_before_gather = context.attn_tp_size == 1
if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
residual = hidden_states
hidden_states = layernorm(hidden_states)
with use_symmetric_memory(
get_tp_group(),
disabled=not is_allocation_symmetric(),
):
hidden_states = layernorm(hidden_states)
hidden_states, local_hidden_states = (
get_global_dp_buffer(),
hidden_states,
......
......@@ -17,6 +17,9 @@ from sglang.srt.distributed import (
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.utils import get_bool_env_var, is_hip
if TYPE_CHECKING:
......@@ -86,6 +89,7 @@ class _DpGatheredBufferWrapper:
_device: torch.device
_global_dp_buffer_len: int
_local_dp_buffer_len: int
_dp_max_padding: bool
_global_num_tokens: Optional[List[int]]
_is_extend_in_batch: bool
......@@ -100,27 +104,33 @@ class _DpGatheredBufferWrapper:
cls,
global_dp_buffer_len: int,
local_dp_buffer_len: int,
dp_max_padding: bool,
global_num_tokens: Optional[List[int]] = None,
):
cls._global_dp_buffer_len = global_dp_buffer_len
cls._local_dp_buffer_len = local_dp_buffer_len
cls._dp_max_padding = dp_max_padding
cls._global_num_tokens = global_num_tokens
@classmethod
def get_global_dp_buffer(cls) -> torch.Tensor:
return torch.empty(
(cls._global_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
with use_symmetric_memory(get_tp_group()):
buffer = torch.empty(
(cls._global_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
return buffer
@classmethod
def get_local_dp_buffer(cls) -> torch.Tensor:
return torch.empty(
(cls._local_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
with use_symmetric_memory(get_tp_group(), disabled=not cls._dp_max_padding):
buffer = torch.empty(
(cls._local_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
return buffer
@classmethod
def get_global_dp_buffer_len(cls) -> int:
......@@ -154,14 +164,19 @@ class _DpGatheredBufferWrapper:
def get_is_extend_in_batch(cls) -> bool:
return cls._is_extend_in_batch
@classmethod
def is_dp_max_padding(cls) -> bool:
return cls._dp_max_padding
def set_dp_buffer_len(
global_dp_buffer_len: int,
local_dp_buffer_len: int,
dp_max_padding: bool,
global_num_tokens: Optional[List[int]] = None,
):
_DpGatheredBufferWrapper.set_dp_buffer_len(
global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
global_dp_buffer_len, local_dp_buffer_len, dp_max_padding, global_num_tokens
)
......@@ -205,6 +220,10 @@ def get_is_extend_in_batch() -> bool:
return _DpGatheredBufferWrapper.get_is_extend_in_batch()
def is_dp_max_padding() -> bool:
return _DpGatheredBufferWrapper.is_dp_max_padding()
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
if not enable_dp_attention:
return tp_rank, tp_size, 0
......@@ -298,6 +317,10 @@ def is_dp_attention_enabled() -> bool:
return _ENABLE_DP_ATTENTION_FLAG
def is_allocation_symmetric() -> bool:
return not is_dp_attention_enabled() or is_dp_max_padding()
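A short sketch of how the new dp_max_padding flag and this predicate fit together, mirroring the call sites changed elsewhere in this diff (values and shapes are made up; an initialized TP group is assumed):

import torch

from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import is_allocation_symmetric, set_dp_buffer_len

# The runner records whether this batch is padded to the DP-wide max length,
# e.g. forward_batch.dp_padding_mode.is_max_len().
set_dp_buffer_len(global_dp_buffer_len=8192, local_dp_buffer_len=1024, dp_max_padding=True)

# Allocation sites opt into symmetric memory only when every rank allocates the
# same shape, which is exactly what is_allocation_symmetric() checks.
with use_symmetric_memory(get_tp_group(), disabled=not is_allocation_symmetric()):
    out = torch.empty((1024, 4096), dtype=torch.bfloat16, device="cuda")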
def get_attention_tp_group() -> GroupCoordinator:
assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
return _ATTN_TP_GROUP
......
......@@ -21,6 +21,7 @@ from sglang.srt.distributed import (
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.parameter import (
BasevLLMParameter,
BlockQuantScaleParameter,
......@@ -1372,9 +1373,10 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
with use_symmetric_memory(get_tp_group()) as sm:
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
sm.tag(output_parallel)
if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
output = tensor_model_parallel_all_reduce(output_parallel)
......
......@@ -97,7 +97,7 @@ def cutlass_fused_experts_fp8(
b_scales_ptrs (torch.Tensor): Container of pointers used to compute offsets of the input scales for each expert.
use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with
block scaling. Currently, only `True` is supported. Defaults to `True`.
output (torch.Tensor, optional): Output tensor. If not provided, a new tensor will be created.
Returns:
torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`.
......
......@@ -18,6 +18,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.moe import (
MoeRunnerConfig,
get_deepep_mode,
......@@ -841,7 +842,9 @@ class FusedMoE(torch.nn.Module):
**kwargs,
)
with use_symmetric_memory(get_tp_group()) as sm:
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
final_hidden_states = self.dispatcher.combine(combine_input=combine_input)
# TODO: should we add some conditions here?
......@@ -849,8 +852,6 @@ class FusedMoE(torch.nn.Module):
..., :origin_hidden_states_dim
].contiguous()
sm.tag(final_hidden_states)
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
......@@ -1048,10 +1049,10 @@ class FlashInferFP4MoE(FusedMoE):
router_logits = router_logits.to(torch.float32)
with use_symmetric_memory(get_tp_group()) as sm:
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
symm_output = torch.empty_like(hidden_states)
sm.tag(symm_output)
result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
......
......@@ -32,12 +32,17 @@ import torch
import torch.nn.functional as F
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.eplb import expert_location_dispatch
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location_dispatch import (
ExpertLocationDispatchInfo,
topk_ids_logical_to_physical,
)
from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.utils import (
cpu_has_amx_support,
......@@ -279,13 +284,17 @@ class TopK(CustomOp):
)
else:
self.topk_config.torch_native = False
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
topk_output = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
return topk_output
def forward_cpu(
self,
......@@ -386,8 +395,11 @@ class TopK(CustomOp):
def empty_topk_output(self, device: torch.device) -> TopKOutput:
topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
# FIXME: router_logits should be of size (0, num_experts)
router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
return StandardTopKOutput(topk_weights, topk_ids, router_logits)
......
......@@ -10,6 +10,12 @@ import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter
from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import is_allocation_symmetric
try:
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
......@@ -1033,9 +1039,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if get_moe_runner_backend().is_cutlass():
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
with use_symmetric_memory(get_tp_group()) as sm:
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
symm_output = torch.empty_like(x)
sm.tag(symm_output)
topk_weights, topk_ids, _ = dispatch_output.topk_output
output = cutlass_fused_experts_fp8(
......@@ -1208,12 +1215,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
else topk_config.correction_bias.to(x.dtype)
)
with use_symmetric_memory(get_tp_group()) as sm:
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
# FIXME: there is a bug in trtllm_fp8_block_scale_moe.
# It ignores the ``output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
# So we put the whole call under the ``use_symmetric_memory`` context manager.
# Once the bug is fixed, we can put only the output tensor allocation under the context manager.
output = trtllm_fp8_block_scale_moe(
return trtllm_fp8_block_scale_moe(
routing_logits=router_logits.to(torch.float32),
routing_bias=correction_bias,
hidden_states=a_q,
......@@ -1238,8 +1247,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routing_method_type=2, # DeepSeek-styled routing method
use_shuffled_weight=False,
)
sm.tag(output)
return output
def maybe_apply_hip_fused_experts(
self,
......
......@@ -11,7 +11,11 @@ from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
from sglang.srt.layers.dp_attention import (
get_dp_global_num_tokens,
get_local_dp_buffer,
is_allocation_symmetric,
)
from sglang.srt.layers.moe import (
MoeRunner,
MoeRunnerBackend,
......@@ -663,7 +667,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
None if correction_bias is None else correction_bias.to(torch.bfloat16)
)
with use_symmetric_memory(get_tp_group()) as sm:
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
# FIXME: there is a bug in trtllm_fp8_block_scale_moe.
# It ignores the ``output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
# So we put the whole call under the ``use_symmetric_memory`` context manager.
......@@ -693,7 +699,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
tile_tokens_dim=None,
routing_method_type=routing_method_type,
)
sm.tag(output)
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
......@@ -1581,38 +1586,42 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
output_dtype = x.dtype
original_col = x.shape[1]
x_sf = None
if should_use_flashinfer_cutlass_moe_fp4_allgather():
from flashinfer import nvfp4_block_scale_interleave
# Quantize before comm, swizzle after.
if x.shape[0] > 0:
x, x_sf = fp4_quantize_flashinfer(
x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
)
else:
x_col = x.shape[1]
x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device)
x_sf = torch.zeros(
0, x_col // 16, dtype=torch.uint8, device=x.device
)
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
if x.shape[0] > 0:
x, x_sf = fp4_quantize_flashinfer(
x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
)
else:
x_col = x.shape[1]
x = torch.zeros(
0, x_col // 2, dtype=torch.uint8, device=x.device
)
x_sf = torch.zeros(
0, x_col // 16, dtype=torch.uint8, device=x.device
)
topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
[topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens()
)
x_sf = nvfp4_block_scale_interleave(x_sf)
with use_symmetric_memory(get_tp_group()) as sm:
# The x might be packed in the case of fp4. So, use the output dim of the
# weight of the second GEMM.
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
symm_output = torch.empty(
x.shape[0],
layer.w2_weight.shape[1],
dtype=output_dtype,
device=x.device,
x.shape[0], original_col, dtype=output_dtype, device=x.device
)
sm.tag(symm_output)
output = flashinfer_cutlass_fused_moe(
output=symm_output,
input=x,
token_selected_experts=topk_ids.to(torch.int),
token_final_scales=topk_weights,
......@@ -1633,7 +1642,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
tp_size=layer.moe_tp_size,
tp_rank=layer.moe_tp_rank,
tune_max_num_tokens=next_power_of_2(x.shape[0]),
output=symm_output,
)[0]
if should_use_flashinfer_cutlass_moe_fp4_allgather():
output, global_output = get_local_dp_buffer(), output
......
......@@ -26,6 +26,7 @@ from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.utils import get_moe_runner_backend
......@@ -640,10 +641,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
top_k = topk_output.topk_config.top_k
router_logits = topk_output.router_logits
with use_symmetric_memory(get_tp_group()) as sm:
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
symm_output = torch.empty_like(x)
sm.tag(symm_output)
trtllm_gen_output = trtllm_fp4_block_scale_moe(
router_logits.to(torch.bfloat16),
None, # routing_bias
......
......@@ -473,9 +473,8 @@ class VocabParallelEmbedding(torch.nn.Module):
else:
masked_input = input_
# Get the embeddings.
with use_symmetric_memory(get_tp_group()) as sm:
with use_symmetric_memory(get_tp_group(), disabled=not self.enable_tp):
output_parallel = self.quant_method.embedding(self, masked_input.long())
sm.tag(output_parallel)
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
......
......@@ -660,7 +660,11 @@ class CudaGraphRunner:
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
set_dp_buffer_len(
global_dp_buffer_len,
num_tokens,
forward_batch.dp_padding_mode.is_max_len(),
)
set_is_extend_in_batch(False)
kwargs = {}
......
......@@ -719,7 +719,9 @@ class ForwardBatch:
num_tokens = global_num_tokens[0]
self.global_dp_buffer_len = buffer_len
set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
set_dp_buffer_len(
buffer_len, num_tokens, dp_padding_mode.is_max_len(), global_num_tokens
)
set_is_extend_in_batch(self.is_extend_in_batch)
bs = self.batch_size
......
......@@ -480,7 +480,8 @@ class DeepseekV2MLP(nn.Module):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(
x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
x,
skip_all_reduce=should_allreduce_fusion or use_reduce_scatter,
)
return x
......@@ -814,7 +815,6 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
final_hidden_states += shared_output
if (
self.tp_size > 1
and not should_allreduce_fusion
......@@ -883,7 +883,9 @@ class DeepseekV2MoE(nn.Module):
return final_hidden_states
def forward_deepep(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
self,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
shared_output = None
if hidden_states.shape[0] > 0:
......
......@@ -45,6 +45,7 @@ from sglang.srt.layers.communicator import (
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
is_allocation_symmetric,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
......@@ -482,12 +483,13 @@ class Glm4MoeSparseMoeBlock(nn.Module):
final_hidden_states *= self.routed_scaling_factor
current_stream.wait_stream(self.alt_stream)
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
with use_symmetric_memory(
parallel_state.get_tp_group(), disabled=not is_allocation_symmetric()
):
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if (
self.tp_size > 1
and not should_allreduce_fusion
......@@ -517,11 +519,12 @@ class Glm4MoeSparseMoeBlock(nn.Module):
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
with use_symmetric_memory(
parallel_state.get_tp_group(), disabled=not is_allocation_symmetric()
):
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if (
self.tp_size > 1
and not should_allreduce_fusion
......
......@@ -85,6 +85,7 @@ class _StageExecutor:
self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
self._global_num_tokens = forward_batch.global_num_tokens_cpu
self._is_dp_max_padding = forward_batch.dp_padding_mode.is_max_len()
def next(self):
assert not self.done
......@@ -95,6 +96,7 @@ class _StageExecutor:
set_dp_buffer_len(
self._global_dp_buffer_len,
self._local_dp_buffer_len,
self._is_dp_max_padding,
self._global_num_tokens,
)
......
......@@ -263,7 +263,11 @@ class EAGLEDraftCudaGraphRunner:
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
set_dp_buffer_len(
global_dp_buffer_len,
num_tokens,
forward_batch.dp_padding_mode.is_max_len(),
)
set_is_extend_in_batch(False)
# Backup two fields, which will be modified in-place in `draft_forward`.
......
......@@ -294,7 +294,11 @@ class EAGLEDraftExtendCudaGraphRunner:
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
set_dp_buffer_len(
global_dp_buffer_len,
num_tokens,
forward_batch.dp_padding_mode.is_max_len(),
)
set_is_extend_in_batch(False)
# Backup two fields, which will be modified in-place in `draft_forward`.
......