Unverified Commit 23407983 authored by Nicolas Castet, committed by GitHub

Register allgather/reducescatter buffers with symm memory (#12572)

parent 1357ab02
 import os
 import tempfile
+from contextlib import nullcontext

 import torch
 from torch.cuda.memory import CUDAPluggableAllocator

 from sglang.srt.distributed.parallel_state import GroupCoordinator
-from sglang.srt.server_args import get_global_server_args

 nccl_allocator_source = """
@@ -60,6 +60,9 @@ _cur_device = None

 def is_symmetric_memory_enabled():
+    # Import here to avoid circular import
+    from sglang.srt.server_args import get_global_server_args
+
     return get_global_server_args().enable_symm_mem
@@ -92,7 +95,7 @@ def get_nccl_mem_pool():
     return _mem_pool


-class use_symmetric_memory:
+class SymmetricMemoryContext:
     """
     Context manager for using symmetric memory with pynccl.
@@ -100,25 +103,17 @@ class use_symmetric_memory:
     by `ncclMemAlloc` and registered by `ncclCommWindowRegister`. Due to this, we introduce
     this context manager. All tensors created under this context will be correctly
     allocated and registered with a custom allocator.
-    In addition, developers need to manually tag the tensors that will be used as the input/output
-    of NCCL collectives with `tag(tensor)`.
     """

-    def __init__(self, group_coordinator: GroupCoordinator):
-        self.enabled = is_symmetric_memory_enabled()
-
-        if not self.enabled:
-            return
-
+    def __init__(
+        self,
+        group_coordinator: GroupCoordinator,
+    ):
         self.group_coordinator = group_coordinator
         self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
         self.is_graph_capture = torch.cuda.is_current_stream_capturing()

     def __enter__(self):
-        if not self.enabled:
-            return self
-
         assert (
             self.group_coordinator.pynccl_comm is not None
         ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
@@ -139,16 +134,16 @@ class use_symmetric_memory:
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-        if not self.enabled:
-            return
-
         self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
         if self.is_graph_capture:
            torch._C._cuda_beginAllocateCurrentThreadToPool(_cur_device, _graph_pool_id)

-    def tag(self, tensor: torch.Tensor):
-        if not self.enabled:
-            return
-
-        tensor.symmetric_memory = True
+
+def use_symmetric_memory(group_coordinator: GroupCoordinator, disabled: bool = False):
+    disabled = (
+        not is_symmetric_memory_enabled()
+        or disabled
+        or group_coordinator.world_size == 1
+    )
+    return SymmetricMemoryContext(group_coordinator) if not disabled else nullcontext()
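For orientation, the call pattern this factory enables looks like the sketch below. It is illustrative only (not part of the diff), assumes a TP group with pynccl has already been initialized and that symmetric memory is enabled via the `enable_symm_mem` server arg; the shape and dtype are made up.

import torch

from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)

# When symmetric memory is enabled and world_size > 1, allocations inside the
# block come from the NCCL mem pool and are window-registered; otherwise the
# factory returns a nullcontext and this is an ordinary torch.empty.
with use_symmetric_memory(get_tp_group()):
    buf = torch.empty((1024, 4096), dtype=torch.bfloat16, device="cuda")

# The buffer can be handed straight to all-gather / reduce-scatter collectives;
# the explicit sm.tag(...) step from the previous API is gone.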
@@ -188,6 +188,27 @@ if _supports_custom_op:
        fake_impl=reg_all_gather_into_tensor_fake,
    )

+    def reg_reduce_scatter_tensor(
+        output: torch.Tensor, input: torch.Tensor, group_name: str
+    ) -> None:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        group._reduce_scatter_tensor(output, input)
+
+    def reg_reduce_scatter_tensor_fake(
+        output: torch.Tensor, input: torch.Tensor, group_name: str
+    ) -> None:
+        pass
+
+    direct_register_custom_op(
+        op_name="reg_reduce_scatter_tensor",
+        op_func=reg_reduce_scatter_tensor,
+        mutates_args=["output"],
+        fake_impl=reg_reduce_scatter_tensor_fake,
+    )
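The registration mirrors the existing `reg_all_gather_into_tensor` op: the real implementation looks the group up by name and mutates `output`, while the fake implementation lets torch.compile trace the call without touching NCCL. As a hedged sketch of the dispatch side (it simply mirrors the `reduce_scatter_tensor` wrapper added later in this diff, with `group` assumed to be a registered GroupCoordinator):

# Sketch only; mirrors GroupCoordinator.reduce_scatter_tensor below.
def reduce_scatter_via_custom_op(group, output: torch.Tensor, input: torch.Tensor):
    if _is_npu or not supports_custom_op():
        # Direct call when the custom-op path is unavailable.
        group._reduce_scatter_tensor(output, input)
    else:
        # Opaque op for torch.compile; mutation of `output` is declared via mutates_args.
        torch.ops.sglang.reg_reduce_scatter_tensor(
            output, input, group_name=group.unique_name
        )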
class GroupCoordinator:
    """
@@ -314,10 +335,16 @@ class GroupCoordinator:
        from sglang.srt.distributed.device_communicators.pynccl import (
            PyNcclCommunicator,
        )
+        from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+            is_symmetric_memory_enabled,
+            use_symmetric_memory,
+        )
        from sglang.srt.distributed.device_communicators.torch_symm_mem import (
            TorchSymmMemCommunicator,
        )

+        self.is_symmetric_memory_enabled = is_symmetric_memory_enabled
+        self.use_symmetric_memory = use_symmetric_memory
+
        if is_hip():
            from sglang.srt.distributed.device_communicators.quick_all_reduce import (
                QuickAllReduce,
@@ -552,7 +579,7 @@ class GroupCoordinator:
        if self.npu_communicator is not None and not self.npu_communicator.disabled:
            return self.npu_communicator.all_reduce(input_)

-        if self.pynccl_comm is not None and getattr(input_, "symmetric_memory", False):
+        if self.pynccl_comm is not None and self.is_symmetric_memory_enabled():
            with self.pynccl_comm.change_state(
                enable=True, stream=get_current_device_stream_fast()
            ):
@@ -627,15 +654,33 @@
        else:
            torch.distributed.all_reduce(input_, group=self.device_group)

-    def reduce_scatter_tensor(
+    def _reduce_scatter_tensor(
        self,
        output: torch.Tensor,
        input: torch.Tensor,
-    ) -> None:
-        # TODO(ch-wan): support other backends
-        torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
+    ) -> torch.Tensor:
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and (
+            not pynccl_comm.disabled or self.is_symmetric_memory_enabled()
+        ):
+            with pynccl_comm.change_state(
+                enable=True, stream=get_current_device_stream_fast()
+            ):
+                pynccl_comm.reduce_scatter(output, input)
+        else:
+            torch.distributed.reduce_scatter_tensor(
+                output, input, group=self.device_group
+            )
        return output

+    def reduce_scatter_tensor(self, output: torch.Tensor, input: torch.Tensor):
+        if _is_npu or not supports_custom_op():
+            self._reduce_scatter_tensor(output, input)
+        else:
+            torch.ops.sglang.reg_reduce_scatter_tensor(
+                output, input, group_name=self.unique_name
+            )
+
    def reduce_scatter(
        self,
        output: torch.Tensor,
@@ -682,8 +727,13 @@
    def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
        pynccl_comm = self.pynccl_comm
-        if pynccl_comm is not None and not pynccl_comm.disabled:
-            pynccl_comm.all_gather(output, input)
+        if pynccl_comm is not None and (
+            not pynccl_comm.disabled or self.is_symmetric_memory_enabled()
+        ):
+            with pynccl_comm.change_state(
+                enable=True, stream=get_current_device_stream_fast()
+            ):
+                pynccl_comm.all_gather(output, input)
        else:
            torch.distributed.all_gather_into_tensor(
                output, input, group=self.device_group
@@ -745,9 +795,10 @@
        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * world_size,) + input_size[1:]
        # Allocate output tensor.
-        output_tensor = torch.empty(
-            output_size, dtype=input_.dtype, device=input_.device
-        )
+        with self.use_symmetric_memory(self):
+            output_tensor = torch.empty(
+                output_size, dtype=input_.dtype, device=input_.device
+            )
        # All-gather.
        if input_.is_cpu:
@@ -787,7 +838,7 @@
            pynccl_comm is not None and not pynccl_comm.disabled
        ), "pynccl is required for all_gatherv"

-        def _all_gather_single(
+        def _all_gather_allocate_output(
            input_: torch.Tensor, sizes: Optional[List[int]] = None
        ):
            input_size = input_.size()
@@ -801,19 +852,25 @@
            else:
                output_size = (input_size[0] * world_size,) + input_size[1:]
            # Allocate output tensor.
-            output_tensor = torch.empty(
-                output_size, dtype=input_.dtype, device=input_.device
-            )
-            pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
-            return output_tensor
+            with self.use_symmetric_memory(self, disabled=sizes is not None):
+                output_tensor = torch.empty(
+                    output_size, dtype=input_.dtype, device=input_.device
+                )
+            return output_tensor, sizes

        if isinstance(input_, torch.Tensor):
-            return _all_gather_single(input_, sizes)
+            input_ = [input_]

        output_list = []
-        pynccl_comm.group_start()
+        size_list = []
        for inp in input_:
-            output_list.append(_all_gather_single(inp, sizes=sizes))
+            output_tensor, s = _all_gather_allocate_output(inp, sizes=sizes)
+            output_list.append(output_tensor)
+            size_list.append(s)
+
+        pynccl_comm.group_start()
+        for i, inp in enumerate(input_):
+            pynccl_comm.all_gather(output_list[i], inp, sizes=size_list[i])
        pynccl_comm.group_end()
        return output_list
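With allocation split out of the gather helper, `all_gatherv` now issues all NCCL calls inside a single group_start()/group_end() window. A hedged usage sketch of the list form follows (the same shape as the NVFP4 MoE caller later in this diff); `tp_group`, the sizes, and the tensor shapes here are made up for illustration.

# Illustrative only; assumes `tp_group` is an initialized GroupCoordinator
# with pynccl enabled (required by all_gatherv).
sizes = [8, 8, 8, 8]  # per-rank row counts, e.g. from get_dp_global_num_tokens()
topk_weights = torch.randn(8, 4, device="cuda")
topk_ids = torch.randint(0, 64, (8, 4), dtype=torch.int32, device="cuda")

# Outputs are allocated first (symmetric memory is skipped when `sizes` is
# given), then every all_gather is issued inside one grouped NCCL call.
gathered_weights, gathered_ids = tp_group.all_gatherv(
    [topk_weights, topk_ids], sizes=sizes
)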
......
@@ -21,8 +21,12 @@ import torch
from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
+    get_tp_group,
    tensor_model_parallel_all_reduce,
)
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
from sglang.srt.layers.dp_attention import (
    attn_tp_all_gather_into_tensor,
    attn_tp_reduce_scatter_tensor,
@@ -34,6 +38,7 @@ from sglang.srt.layers.dp_attention import (
    get_attention_tp_size,
    get_global_dp_buffer,
    get_local_dp_buffer,
+    is_allocation_symmetric,
    is_dp_attention_enabled,
)
from sglang.srt.layers.moe import (
@@ -540,7 +545,12 @@ class CommunicateWithAllReduceAndLayerNormFn:
        use_layer_norm_before_gather = context.attn_tp_size == 1
        if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
            residual = hidden_states
-            hidden_states = layernorm(hidden_states)
+            with use_symmetric_memory(
+                get_tp_group(),
+                disabled=not is_allocation_symmetric(),
+            ):
+                hidden_states = layernorm(hidden_states)

        hidden_states, local_hidden_states = (
            get_global_dp_buffer(),
            hidden_states,
......
@@ -17,6 +17,9 @@ from sglang.srt.distributed import (
    get_tp_group,
    tensor_model_parallel_all_reduce,
)
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
from sglang.srt.utils import get_bool_env_var, is_hip

if TYPE_CHECKING:
@@ -86,6 +89,7 @@ class _DpGatheredBufferWrapper:
    _device: torch.device
    _global_dp_buffer_len: int
    _local_dp_buffer_len: int
+    _dp_max_padding: bool
    _global_num_tokens: Optional[List[int]]
    _is_extend_in_batch: bool
@@ -100,27 +104,33 @@
        cls,
        global_dp_buffer_len: int,
        local_dp_buffer_len: int,
+        dp_max_padding: bool,
        global_num_tokens: Optional[List[int]] = None,
    ):
        cls._global_dp_buffer_len = global_dp_buffer_len
        cls._local_dp_buffer_len = local_dp_buffer_len
+        cls._dp_max_padding = dp_max_padding
        cls._global_num_tokens = global_num_tokens

    @classmethod
    def get_global_dp_buffer(cls) -> torch.Tensor:
-        return torch.empty(
-            (cls._global_dp_buffer_len, cls._hidden_size),
-            dtype=cls._dtype,
-            device=cls._device,
-        )
+        with use_symmetric_memory(get_tp_group()):
+            buffer = torch.empty(
+                (cls._global_dp_buffer_len, cls._hidden_size),
+                dtype=cls._dtype,
+                device=cls._device,
+            )
+        return buffer

    @classmethod
    def get_local_dp_buffer(cls) -> torch.Tensor:
-        return torch.empty(
-            (cls._local_dp_buffer_len, cls._hidden_size),
-            dtype=cls._dtype,
-            device=cls._device,
-        )
+        with use_symmetric_memory(get_tp_group(), disabled=not cls._dp_max_padding):
+            buffer = torch.empty(
+                (cls._local_dp_buffer_len, cls._hidden_size),
+                dtype=cls._dtype,
+                device=cls._device,
+            )
+        return buffer

    @classmethod
    def get_global_dp_buffer_len(cls) -> int:
@@ -154,14 +164,19 @@
    def get_is_extend_in_batch(cls) -> bool:
        return cls._is_extend_in_batch

+    @classmethod
+    def is_dp_max_padding(cls) -> bool:
+        return cls._dp_max_padding
+

def set_dp_buffer_len(
    global_dp_buffer_len: int,
    local_dp_buffer_len: int,
+    dp_max_padding: bool,
    global_num_tokens: Optional[List[int]] = None,
):
    _DpGatheredBufferWrapper.set_dp_buffer_len(
-        global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
+        global_dp_buffer_len, local_dp_buffer_len, dp_max_padding, global_num_tokens
    )
@@ -205,6 +220,10 @@ def get_is_extend_in_batch() -> bool:
    return _DpGatheredBufferWrapper.get_is_extend_in_batch()


+def is_dp_max_padding() -> bool:
+    return _DpGatheredBufferWrapper.is_dp_max_padding()
+
+
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
@@ -298,6 +317,10 @@ def is_dp_attention_enabled() -> bool:
    return _ENABLE_DP_ATTENTION_FLAG


+def is_allocation_symmetric() -> bool:
+    return not is_dp_attention_enabled() or is_dp_max_padding()
+
+
def get_attention_tp_group() -> GroupCoordinator:
    assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
    return _ATTN_TP_GROUP
......
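The new `is_allocation_symmetric()` helper gates the layer-level allocations in the rest of this diff: the apparent intent is to register buffers only when every rank is expected to allocate the same shape, which holds when DP attention is off or when the batch is padded to the DP max length. The recurring caller pattern, as a hedged sketch (`hidden_states` is a placeholder for whatever activation is being allocated):

import torch

from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import is_allocation_symmetric

with use_symmetric_memory(
    get_tp_group(), disabled=not is_allocation_symmetric()
):
    # Registered with NCCL symmetric memory only when all ranks allocate
    # identical shapes; otherwise this is a plain allocation.
    out = torch.empty_like(hidden_states)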
@@ -21,6 +21,7 @@ from sglang.srt.distributed import (
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
+from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
@@ -1372,9 +1373,10 @@ class RowParallelLinear(LinearBase):
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        with use_symmetric_memory(get_tp_group()) as sm:
+        with use_symmetric_memory(
+            get_tp_group(), disabled=not is_allocation_symmetric()
+        ):
            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
-            sm.tag(output_parallel)

        if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
            output = tensor_model_parallel_all_reduce(output_parallel)
......
@@ -97,7 +97,7 @@ def cutlass_fused_experts_fp8(
        b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
        use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with
            block scaling. Currently, only `True` is supported. Defaults to `True`.
+        output (torch.Tensor, optional): Output tensor. If not provided, a new tensor will be created.

    Returns:
        torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`.
......
@@ -18,6 +18,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.moe import (
    MoeRunnerConfig,
    get_deepep_mode,
@@ -841,7 +842,9 @@ class FusedMoE(torch.nn.Module):
            **kwargs,
        )

-        with use_symmetric_memory(get_tp_group()) as sm:
+        with use_symmetric_memory(
+            get_tp_group(), disabled=not is_allocation_symmetric()
+        ):
            final_hidden_states = self.dispatcher.combine(combine_input=combine_input)

            # TODO: should we add some conditions here?
@@ -849,8 +852,6 @@
                ..., :origin_hidden_states_dim
            ].contiguous()

-            sm.tag(final_hidden_states)
-
        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
@@ -1048,10 +1049,10 @@ class FlashInferFP4MoE(FusedMoE):
        router_logits = router_logits.to(torch.float32)

-        with use_symmetric_memory(get_tp_group()) as sm:
+        with use_symmetric_memory(
+            get_tp_group(), disabled=not is_allocation_symmetric()
+        ):
            symm_output = torch.empty_like(hidden_states)
-            sm.tag(symm_output)

        result = trtllm_fp4_block_scale_moe(
            routing_logits=router_logits,
            routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
......
@@ -32,12 +32,17 @@ import torch
import torch.nn.functional as F

from sglang.srt.custom_op import CustomOp
+from sglang.srt.distributed import get_tp_group
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
from sglang.srt.eplb import expert_location_dispatch
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location_dispatch import (
    ExpertLocationDispatchInfo,
    topk_ids_logical_to_physical,
)
+from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.utils import (
    cpu_has_amx_support,
@@ -279,13 +284,17 @@ class TopK(CustomOp):
            )
        else:
            self.topk_config.torch_native = False
-        return select_experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            topk_config=self.topk_config,
-            num_token_non_padded=num_token_non_padded,
-            expert_location_dispatch_info=expert_location_dispatch_info,
-        )
+        with use_symmetric_memory(
+            get_tp_group(), disabled=not is_allocation_symmetric()
+        ):
+            topk_output = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                topk_config=self.topk_config,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )
+        return topk_output

    def forward_cpu(
        self,
@@ -386,8 +395,11 @@ class TopK(CustomOp):
    def empty_topk_output(self, device: torch.device) -> TopKOutput:
        topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
-        topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
-        topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
+        with use_symmetric_memory(
+            get_tp_group(), disabled=not is_allocation_symmetric()
+        ):
+            topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
+            topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
        # FIXME: router_logits should be of size (0, num_experts)
        router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
        return StandardTopKOutput(topk_weights, topk_ids, router_logits)
......
@@ -10,6 +10,12 @@ import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter

+from sglang.srt.distributed import get_tp_group
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
+from sglang.srt.layers.dp_attention import is_allocation_symmetric
+
try:
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        apply_fp8_marlin_linear,
@@ -1033,9 +1039,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        if get_moe_runner_backend().is_cutlass():
            from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

-            with use_symmetric_memory(get_tp_group()) as sm:
+            with use_symmetric_memory(
+                get_tp_group(), disabled=not is_allocation_symmetric()
+            ):
                symm_output = torch.empty_like(x)
-                sm.tag(symm_output)

            topk_weights, topk_ids, _ = dispatch_output.topk_output
            output = cutlass_fused_experts_fp8(
@@ -1208,12 +1215,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                else topk_config.correction_bias.to(x.dtype)
            )

-            with use_symmetric_memory(get_tp_group()) as sm:
+            with use_symmetric_memory(
+                get_tp_group(), disabled=not is_allocation_symmetric()
+            ):
                # FIXME: there is a bug in the trtllm_fp8_block_scale_moe.
                # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
                # so we put the whole function under the ``use_symmetric_memory`` context manager.
                # If the bug is fixed, we can only put the output tensor allocation under the context manager.
-                output = trtllm_fp8_block_scale_moe(
+                return trtllm_fp8_block_scale_moe(
                    routing_logits=router_logits.to(torch.float32),
                    routing_bias=correction_bias,
                    hidden_states=a_q,
@@ -1238,8 +1247,6 @@
                    routing_method_type=2,  # DeepSeek-styled routing method
                    use_shuffled_weight=False,
                )
-                sm.tag(output)
-                return output

    def maybe_apply_hip_fused_experts(
        self,
......
@@ -11,7 +11,11 @@ from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
-from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
+from sglang.srt.layers.dp_attention import (
+    get_dp_global_num_tokens,
+    get_local_dp_buffer,
+    is_allocation_symmetric,
+)
from sglang.srt.layers.moe import (
    MoeRunner,
    MoeRunnerBackend,
@@ -663,7 +667,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
            None if correction_bias is None else correction_bias.to(torch.bfloat16)
        )

-        with use_symmetric_memory(get_tp_group()) as sm:
+        with use_symmetric_memory(
+            get_tp_group(), disabled=not is_allocation_symmetric()
+        ):
            # FIXME: there is a bug in the trtllm_fp8_block_scale_moe.
            # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
            # so we put the whole function under the ``use_symmetric_memory`` context manager.
@@ -693,7 +699,6 @@
                tile_tokens_dim=None,
                routing_method_type=routing_method_type,
            )
-            sm.tag(output)

        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
@@ -1581,38 +1586,42 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids

        output_dtype = x.dtype
+        original_col = x.shape[1]
        x_sf = None
        if should_use_flashinfer_cutlass_moe_fp4_allgather():
            from flashinfer import nvfp4_block_scale_interleave

            # Quantize before comm, swizzle after.
-            if x.shape[0] > 0:
-                x, x_sf = fp4_quantize_flashinfer(
-                    x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
-                )
-            else:
-                x_col = x.shape[1]
-                x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device)
-                x_sf = torch.zeros(
-                    0, x_col // 16, dtype=torch.uint8, device=x.device
-                )
+            with use_symmetric_memory(
+                get_tp_group(), disabled=not is_allocation_symmetric()
+            ):
+                if x.shape[0] > 0:
+                    x, x_sf = fp4_quantize_flashinfer(
+                        x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
+                    )
+                else:
+                    x_col = x.shape[1]
+                    x = torch.zeros(
+                        0, x_col // 2, dtype=torch.uint8, device=x.device
+                    )
+                    x_sf = torch.zeros(
+                        0, x_col // 16, dtype=torch.uint8, device=x.device
+                    )
            topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
                [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens()
            )
            x_sf = nvfp4_block_scale_interleave(x_sf)

-        with use_symmetric_memory(get_tp_group()) as sm:
-            # The x might be packed in the case of fp4. So, use the output dim of the
-            # weight of the second GEMM.
-            symm_output = torch.empty(
-                x.shape[0],
-                layer.w2_weight.shape[1],
-                dtype=output_dtype,
-                device=x.device,
-            )
-            sm.tag(symm_output)
+        with use_symmetric_memory(
+            get_tp_group(), disabled=not is_allocation_symmetric()
+        ):
+            symm_output = torch.empty(
+                x.shape[0], original_col, dtype=output_dtype, device=x.device
+            )
        output = flashinfer_cutlass_fused_moe(
+            output=symm_output,
            input=x,
            token_selected_experts=topk_ids.to(torch.int),
            token_final_scales=topk_weights,
@@ -1633,7 +1642,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
            tp_size=layer.moe_tp_size,
            tp_rank=layer.moe_tp_rank,
            tune_max_num_tokens=next_power_of_2(x.shape[0]),
-            output=symm_output,
        )[0]
        if should_use_flashinfer_cutlass_moe_fp4_allgather():
            output, global_output = get_local_dp_buffer(), output
......
@@ -26,6 +26,7 @@ from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
+from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.utils import get_moe_runner_backend
@@ -640,10 +641,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        top_k = topk_output.topk_config.top_k
        router_logits = topk_output.router_logits

-        with use_symmetric_memory(get_tp_group()) as sm:
+        with use_symmetric_memory(
+            get_tp_group(), disabled=not is_allocation_symmetric()
+        ):
            symm_output = torch.empty_like(x)
-            sm.tag(symm_output)

        trtllm_gen_output = trtllm_fp4_block_scale_moe(
            router_logits.to(torch.bfloat16),
            None,  # routing_bias
......
@@ -473,9 +473,8 @@ class VocabParallelEmbedding(torch.nn.Module):
        else:
            masked_input = input_
        # Get the embeddings.
-        with use_symmetric_memory(get_tp_group()) as sm:
+        with use_symmetric_memory(get_tp_group(), disabled=not self.enable_tp):
            output_parallel = self.quant_method.embedding(self, masked_input.long())
-            sm.tag(output_parallel)

        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
......
@@ -660,7 +660,11 @@ class CudaGraphRunner:
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
-            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+            set_dp_buffer_len(
+                global_dp_buffer_len,
+                num_tokens,
+                forward_batch.dp_padding_mode.is_max_len(),
+            )
            set_is_extend_in_batch(False)

            kwargs = {}
......
@@ -719,7 +719,9 @@ class ForwardBatch:
            num_tokens = global_num_tokens[0]

        self.global_dp_buffer_len = buffer_len
-        set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
+        set_dp_buffer_len(
+            buffer_len, num_tokens, dp_padding_mode.is_max_len(), global_num_tokens
+        )
        set_is_extend_in_batch(self.is_extend_in_batch)

        bs = self.batch_size
......
@@ -480,7 +480,8 @@ class DeepseekV2MLP(nn.Module):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(
-            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+            x,
+            skip_all_reduce=should_allreduce_fusion or use_reduce_scatter,
        )
        return x
@@ -814,7 +815,6 @@ class DeepseekV2MoE(nn.Module):
            final_hidden_states *= self.routed_scaling_factor
        if shared_output is not None:
            final_hidden_states += shared_output
-
        if (
            self.tp_size > 1
            and not should_allreduce_fusion
@@ -883,7 +883,9 @@ class DeepseekV2MoE(nn.Module):
        return final_hidden_states

    def forward_deepep(
-        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        shared_output = None
        if hidden_states.shape[0] > 0:
......
@@ -45,6 +45,7 @@ from sglang.srt.layers.communicator import (
from sglang.srt.layers.dp_attention import (
    get_attention_tp_rank,
    get_attention_tp_size,
+    is_allocation_symmetric,
    is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
@@ -482,12 +483,13 @@ class Glm4MoeSparseMoeBlock(nn.Module):
            final_hidden_states *= self.routed_scaling_factor
            current_stream.wait_stream(self.alt_stream)

-            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            with use_symmetric_memory(
+                parallel_state.get_tp_group(), disabled=not is_allocation_symmetric()
+            ):
                final_hidden_states_out = torch.empty_like(final_hidden_states)
                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
                final_hidden_states = final_hidden_states_out
-                sm.tag(final_hidden_states)

            if (
                self.tp_size > 1
                and not should_allreduce_fusion
@@ -517,11 +519,12 @@ class Glm4MoeSparseMoeBlock(nn.Module):
                # fused in biased_grouped_topk so we can skip here
                final_hidden_states *= self.routed_scaling_factor
            if shared_output is not None:
-                with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                with use_symmetric_memory(
+                    parallel_state.get_tp_group(), disabled=not is_allocation_symmetric()
+                ):
                    final_hidden_states_out = torch.empty_like(final_hidden_states)
                    torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
                    final_hidden_states = final_hidden_states_out
-                    sm.tag(final_hidden_states)
            if (
                self.tp_size > 1
                and not should_allreduce_fusion
......
@@ -85,6 +85,7 @@ class _StageExecutor:
        self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
        self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
        self._global_num_tokens = forward_batch.global_num_tokens_cpu
+        self._is_dp_max_padding = forward_batch.dp_padding_mode.is_max_len()

    def next(self):
        assert not self.done
@@ -95,6 +96,7 @@
        set_dp_buffer_len(
            self._global_dp_buffer_len,
            self._local_dp_buffer_len,
+            self._is_dp_max_padding,
            self._global_num_tokens,
        )
......
@@ -263,7 +263,11 @@ class EAGLEDraftCudaGraphRunner:
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
-            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+            set_dp_buffer_len(
+                global_dp_buffer_len,
+                num_tokens,
+                forward_batch.dp_padding_mode.is_max_len(),
+            )
            set_is_extend_in_batch(False)

            # Backup two fields, which will be modified in-place in `draft_forward`.
......
@@ -294,7 +294,11 @@ class EAGLEDraftExtendCudaGraphRunner:
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
-            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+            set_dp_buffer_len(
+                global_dp_buffer_len,
+                num_tokens,
+                forward_batch.dp_padding_mode.is_max_len(),
+            )
            set_is_extend_in_batch(False)

            # Backup two fields, which will be modified in-place in `draft_forward`.
......