"examples/git@developer.sourcefind.cn:change/sglang.git" did not exist on "b149b39353ab5295cca80aa13ec1903e9c7e60d9"
Commit 7a21d8b2 authored by Lianmin Zheng, committed by GitHub

Reduce the overhead of nccl symmetric memory (#12524)


Co-authored-by: Nicolas Castet <ncastet@nvidia.com>
parent d36639ee
......@@ -19,6 +19,7 @@ from sglang.srt.distributed.device_communicators.pynccl_wrapper import (
ncclUniqueId,
)
from sglang.srt.distributed.utils import StatelessProcessGroup
from sglang.srt.utils.common import get_current_device_stream_fast
logger = logging.getLogger(__name__)
......@@ -137,7 +138,7 @@ class PyNcclCommunicator:
if stream is not None:
return stream
if self.use_current_stream:
return torch.cuda.current_stream()
return get_current_device_stream_fast()
return self.stream
def all_reduce(
......
import os
import tempfile
import torch
from packaging import version
from torch.cuda.memory import CUDAPluggableAllocator
from sglang.srt.distributed.parallel_state import GroupCoordinator
......@@ -9,13 +9,22 @@ from sglang.srt.server_args import get_global_server_args
nccl_allocator_source = """
#include <nccl.h>
extern "C" {
void* nccl_alloc_plug(size_t size, int device, void* stream) {
void* ptr;
ncclResult_t err = ncclMemAlloc(&ptr, size);
return ptr;
// The communicator handle is passed in from Python via an env var (set in
// use_symmetric_memory.tag()), since the pluggable allocator hooks only receive
// (size, device, stream) and cannot take extra arguments.
const char *str_val = getenv("SGLANG_TMP_NCCL_COMM_VALUE");
char *endptr;
void* int_val = (void *)strtoull(str_val, &endptr, 0);
ncclComm_t comm = (ncclComm_t)(int_val);
// Register the fresh allocation as a window so collectives on this communicator
// can use the NCCL symmetric-memory path.
ncclWindow_t win;
ncclResult_t err2 = ncclCommWindowRegister(comm, ptr, size, &win, NCCL_WIN_COLL_SYMMETRIC);
return ptr;
}
void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
......@@ -27,8 +36,8 @@ void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
_allocator = None
_mem_pool = None
_registered_base_addrs = set()
_graph_pool_id = None
_cur_device = None
def is_symmetric_memory_enabled():
......@@ -41,7 +50,7 @@ def set_graph_pool_id(graph_pool_id):
def get_nccl_mem_pool():
global _allocator, _mem_pool
global _allocator, _mem_pool, _cur_device
if _mem_pool is None:
out_dir = tempfile.gettempdir()
nccl_allocator_libname = "nccl_allocator"
......@@ -60,74 +69,67 @@ def get_nccl_mem_pool():
"nccl_free_plug",
).allocator()
_mem_pool = torch.cuda.MemPool(_allocator)
_cur_device = torch.cuda.current_device()
return _mem_pool
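# Sketch of the wiring that the elided lines above perform (assumptions marked):
# the C source is compiled into a shared library (presumably via
# torch.utils.cpp_extension.load_inline with is_python_module=False; the exact
# build step and .so path are elided in this hunk), then wrapped so every
# allocation made from the pool goes through nccl_alloc_plug / nccl_free_plug:
#
#     _allocator = CUDAPluggableAllocator(
#         path_to_compiled_so,          # assumed location of the built library
#         "nccl_alloc_plug",
#         "nccl_free_plug",
#     ).allocator()
#     _mem_pool = torch.cuda.MemPool(_allocator)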
class use_symmetric_memory:
"""
Context manager for using symmetric memory with pynccl.
To utilize the symmetric memory feature in NCCL, buffers need to be allocated
by `ncclMemAlloc` and registered by `ncclCommWindowRegister`. This context manager
takes care of both: all tensors created under it are allocated and registered
through a custom allocator.
In addition, developers need to manually tag the tensors that will be used as the
input/output of NCCL collectives with `tag(tensor)`.
"""
def __init__(self, group_coordinator: GroupCoordinator):
if not is_symmetric_memory_enabled():
self.group_coordinator = None
self._mem_pool_ctx = None
self.is_graph_capture = None
self.device = None
self.pre_2_8_0 = None
else:
self.group_coordinator = group_coordinator
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
self.device = torch.cuda.current_device()
self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")
self.enabled = is_symmetric_memory_enabled()
if not self.enabled:
return
self.group_coordinator = group_coordinator
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
def __enter__(self):
if not is_symmetric_memory_enabled():
if not self.enabled:
return self
assert (
self.group_coordinator.pynccl_comm is not None
), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
assert (
self.group_coordinator.pynccl_comm.nccl_version >= 22703
), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
if self.is_graph_capture:
assert (
_graph_pool_id is not None
), "graph_pool_id is not set under graph capture"
# Pause graph memory pool to use symmetric memory with cuda graph
if self.pre_2_8_0:
torch._C._cuda_endAllocateCurrentStreamToPool(
self.device, _graph_pool_id
)
else:
torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
torch._C._cuda_endAllocateToPool(_cur_device, _graph_pool_id)
self._mem_pool_ctx.__enter__()
return self
def tag(self, tensor: torch.Tensor):
if not is_symmetric_memory_enabled():
return
tensor.symmetric_memory = True
# Pass the communicator handle to the C allocator functions via an env var.
os.environ["SGLANG_TMP_NCCL_COMM_VALUE"] = str(
self.group_coordinator.pynccl_comm.comm.value
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if not is_symmetric_memory_enabled():
if not self.enabled:
return
global _registered_base_addrs
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
for segment in get_nccl_mem_pool().snapshot():
if segment["address"] not in _registered_base_addrs:
if segment["stream"] == 0 and self.pre_2_8_0:
# PyTorch version < 2.8.0 has a multi-thread MemPool bug
# See https://github.com/pytorch/pytorch/issues/152861
# Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
# The workaround is to skip allocations on the default stream, since the forward_pass thread always runs on a custom stream
continue
self.group_coordinator.pynccl_comm.register_comm_window_raw(
segment["address"], segment["total_size"]
)
_registered_base_addrs.add(segment["address"])
if self.is_graph_capture:
if self.pre_2_8_0:
torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
else:
torch._C._cuda_beginAllocateCurrentThreadToPool(
self.device, _graph_pool_id
)
torch._C._cuda_beginAllocateCurrentThreadToPool(_cur_device, _graph_pool_id)
def tag(self, tensor: torch.Tensor):
if not self.enabled:
return
tensor.symmetric_memory = True
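Illustrative usage of the context manager (a sketch, not part of this commit; it assumes the TP process group and pynccl communicator are already initialized): allocate a result buffer inside the symmetric-memory pool, tag it, and all-reduce it over the TP group so the collective can take the NCCL symmetric-memory path.
import torch
from sglang.srt.distributed import get_tp_group, tensor_model_parallel_all_reduce
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
def reduce_partial_result(partial: torch.Tensor) -> torch.Tensor:
    with use_symmetric_memory(get_tp_group()) as sm:
        out = torch.empty_like(partial)  # served by ncclMemAlloc through the mem pool
        out.copy_(partial)
        sm.tag(out)                      # window-registered when the context exits
    # The tagged tensor is detected in GroupCoordinator.all_reduce and reduced
    # via the pynccl symmetric-memory path.
    return tensor_model_parallel_all_reduce(out)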
......@@ -43,6 +43,7 @@ from sglang.srt.environ import envs
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
get_current_device_stream_fast,
get_int_env_var,
get_local_ip_auto,
is_cpu,
......@@ -466,7 +467,7 @@ class GroupCoordinator:
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = self.device_module.current_stream()
curr_stream = get_current_device_stream_fast()
if curr_stream != stream:
stream.wait_stream(curr_stream)
......@@ -500,7 +501,7 @@ class GroupCoordinator:
maybe_pynccl_context = nullcontext()
else:
maybe_pynccl_context = pynccl_comm.change_state(
enable=True, stream=torch.get_device_module().current_stream()
enable=True, stream=get_current_device_stream_fast()
)
pymscclpp_comm = self.pymscclpp_comm
......@@ -551,13 +552,9 @@ class GroupCoordinator:
if self.npu_communicator is not None and not self.npu_communicator.disabled:
return self.npu_communicator.all_reduce(input_)
if (
self.pynccl_comm is not None
and hasattr(input_, "symmetric_memory")
and input_.symmetric_memory
):
if self.pynccl_comm is not None and getattr(input_, "symmetric_memory", False):
with self.pynccl_comm.change_state(
enable=True, stream=torch.get_device_module().current_stream()
enable=True, stream=get_current_device_stream_fast()
):
self.pynccl_comm.all_reduce(input_)
return input_
......@@ -658,7 +655,7 @@ class GroupCoordinator:
pynccl_comm = self.pynccl_comm
with pynccl_comm.change_state(
enable=True, stream=torch.get_device_module().current_stream()
enable=True, stream=get_current_device_stream_fast()
):
assert (
pynccl_comm is not None and not pynccl_comm.disabled
......@@ -784,7 +781,7 @@ class GroupCoordinator:
pynccl_comm = self.pynccl_comm
with pynccl_comm.change_state(
enable=True, stream=torch.get_device_module().current_stream()
enable=True, stream=get_current_device_stream_fast()
):
assert (
pynccl_comm is not None and not pynccl_comm.disabled
......
......@@ -677,10 +677,16 @@ class Engine(EngineBase):
def _set_envs_and_config(server_args: ServerArgs):
# Set global environments
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
if "NCCL_CUMEM_ENABLE" not in os.environ:
if "NCCL_CUMEM_ENABLE" not in os.environ or server_args.enable_symm_mem:
os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
if not server_args.enable_symm_mem:
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
if (
"NCCL_NVLS_ENABLE" not in os.environ
or server_args.enable_nccl_nvls
or server_args.enable_symm_mem
):
os.environ["NCCL_NVLS_ENABLE"] = str(
int(server_args.enable_nccl_nvls or server_args.enable_symm_mem)
)
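# Net effect of the NCCL settings above (a summary derived from the conditions,
# not part of the commit):
#   --enable-symm-mem   -> NCCL_CUMEM_ENABLE=1 and NCCL_NVLS_ENABLE=1
#   --enable-nccl-nvls  -> NCCL_NVLS_ENABLE=1
#   neither flag        -> values already set in the environment are respected;
#                          unset variables default to 0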
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
......
......@@ -13,7 +13,7 @@ from sglang.srt.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
get_tp_group,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
......@@ -1372,7 +1372,7 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
with use_symmetric_memory(get_tp_group()) as sm:
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
sm.tag(output_parallel)
......
"""CUTLASS based Fused MoE kernels."""
from typing import Optional
import torch
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
......@@ -40,6 +42,7 @@ def cutlass_fused_experts_fp8(
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
use_fp8_blockscale: bool = True,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.
......@@ -200,9 +203,11 @@ def cutlass_fused_experts_fp8(
workspace,
)
result = torch.empty((m, k), device=device, dtype=out_dtype)
apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype))
return result
if output is None:
output = torch.empty((m, k), device=device, dtype=out_dtype)
apply_shuffle_mul_sum(c2, output, c_map, topk_weights.to(out_dtype))
return output
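# Intended call pattern for the new `output` argument (a sketch; the concrete
# caller is Fp8MoEMethod later in this commit): pre-allocate the result inside
# the NCCL symmetric-memory pool and tag it, so the kernel writes into a
# window-registered buffer and the subsequent all-reduce can use the symmetric path.
#
#     with use_symmetric_memory(get_tp_group()) as sm:
#         symm_output = torch.empty_like(x)
#         sm.tag(symm_output)
#     out = cutlass_fused_experts_fp8(x, ..., output=symm_output)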
FLOAT4_E2M1_MAX = 6.0
......
......@@ -14,6 +14,9 @@ from sglang.srt.distributed import (
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe import (
MoeRunnerConfig,
......@@ -55,11 +58,6 @@ from sglang.srt.utils import (
if is_flashinfer_available():
from flashinfer import RoutingMethodType, fp4_quantize
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
# Try to import FP4 TRTLLM function if flashinfer is available
trtllm_fp4_block_scale_moe = None
if should_use_flashinfer_trtllm_moe():
......@@ -68,6 +66,10 @@ if should_use_flashinfer_trtllm_moe():
except ImportError:
trtllm_fp4_block_scale_moe = None
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
logger = logging.getLogger(__name__)
......@@ -839,12 +841,16 @@ class FusedMoE(torch.nn.Module):
dispatch_output=dispatch_output,
**kwargs,
)
final_hidden_states = self.dispatcher.combine(combine_input=combine_input)
# TODO: should we add some conditions here?
final_hidden_states = final_hidden_states[
..., :origin_hidden_states_dim
].contiguous()
with use_symmetric_memory(get_tp_group()) as sm:
final_hidden_states = self.dispatcher.combine(combine_input=combine_input)
# TODO: should we add some conditions here?
final_hidden_states = final_hidden_states[
..., :origin_hidden_states_dim
].contiguous()
sm.tag(final_hidden_states)
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
......@@ -980,6 +986,11 @@ class FlashInferFusedMoE(FusedMoE):
),
)
# NOTE for symmetric memory tagging:
# We do not create the context in this function.
# Instead, we create the context and do the tagging inside each FusedMoEMethodBase.
# This allows fine-grained tagging.
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
......@@ -1040,6 +1051,10 @@ class FlashInferFP4MoE(FusedMoE):
router_logits = router_logits.to(torch.float32)
with use_symmetric_memory(get_tp_group()) as sm:
symm_output = torch.empty_like(hidden_states)
sm.tag(symm_output)
result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
......@@ -1072,6 +1087,7 @@ class FlashInferFP4MoE(FusedMoE):
tile_tokens_dim=None,
routing_method_type=RoutingMethodType.DeepSeekV3,
do_finalize=True,
output=symm_output,
)[0]
return result
......@@ -28,7 +28,10 @@ except ImportError:
apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.distributed import get_tensor_model_parallel_world_size, get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
......@@ -1025,6 +1028,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self._should_use_cutlass_fused_experts():
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
with use_symmetric_memory(get_tp_group()) as sm:
symm_output = torch.empty_like(x)
sm.tag(symm_output)
topk_weights, topk_ids, _ = dispatch_output.topk_output
output = cutlass_fused_experts_fp8(
x,
......@@ -1048,6 +1055,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.problem_sizes1,
self.problem_sizes2,
use_fp8_blockscale=True,
output=symm_output,
)
return StandardCombineInput(hidden_states=output)
......@@ -1211,31 +1219,38 @@ class Fp8MoEMethod(FusedMoEMethodBase):
else topk_config.correction_bias.to(x.dtype)
)
return trtllm_fp8_block_scale_moe(
routing_logits=router_logits.to(torch.float32),
routing_bias=correction_bias,
hidden_states=a_q,
hidden_states_scale=a_sf_t,
gemm1_weights=layer.w13_weight,
gemm1_weights_scale=layer.w13_weight_scale_inv,
gemm2_weights=layer.w2_weight,
gemm2_weights_scale=layer.w2_weight_scale_inv,
num_experts=layer.num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=layer.w2_weight.shape[2],
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,
routed_scaling_factor=(
routed_scaling_factor if routed_scaling_factor is not None else 1.0
),
tile_tokens_dim=get_tile_tokens_dim(
x.shape[0], topk_config.top_k, layer.num_experts
),
routing_method_type=2, # DeepSeek-styled routing method
use_shuffled_weight=False,
)
with use_symmetric_memory(get_tp_group()) as sm:
# FIXME: there is a bug in trtllm_fp8_block_scale_moe: it ignores the `output` argument.
# See https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
# As a workaround, we put the whole call under the `use_symmetric_memory` context manager.
# Once the bug is fixed, only the output tensor allocation needs to stay under the context manager.
output = trtllm_fp8_block_scale_moe(
routing_logits=router_logits.to(torch.float32),
routing_bias=correction_bias,
hidden_states=a_q,
hidden_states_scale=a_sf_t,
gemm1_weights=layer.w13_weight,
gemm1_weights_scale=layer.w13_weight_scale_inv,
gemm2_weights=layer.w2_weight,
gemm2_weights_scale=layer.w2_weight_scale_inv,
num_experts=layer.num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=layer.w2_weight.shape[2],
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,
routed_scaling_factor=(
routed_scaling_factor if routed_scaling_factor is not None else 1.0
),
tile_tokens_dim=get_tile_tokens_dim(
x.shape[0], topk_config.top_k, layer.num_experts
),
routing_method_type=2, # DeepSeek-styled routing method
use_shuffled_weight=False,
)
sm.tag(output)
return output
def maybe_apply_hip_fused_experts(
self,
......
......@@ -8,6 +8,9 @@ import torch
from torch.nn.parameter import Parameter
from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
from sglang.srt.layers.moe import (
MoeRunner,
......@@ -659,29 +662,37 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
None if correction_bias is None else correction_bias.to(torch.bfloat16)
)
output = trtllm_fp8_per_tensor_scale_moe(
routing_logits=routing_logits_cast,
routing_bias=routing_bias_cast,
hidden_states=x_fp8,
gemm1_weights=layer.w13_weight,
output1_scales_scalar=layer.output1_scales_scalar,
output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
gemm2_weights=layer.w2_weight,
output2_scales_scalar=layer.output2_scales_scalar,
num_experts=layer.num_experts,
top_k=topk_config.top_k,
n_group=0,
topk_group=0,
intermediate_size=layer.w2_weight.shape[2],
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,
routed_scaling_factor=(
routed_scaling_factor if routed_scaling_factor is not None else 1.0
),
use_routing_scales_on_input=use_routing_scales_on_input,
tile_tokens_dim=8, # TODO(brayden): use the FI tile calculation
routing_method_type=routing_method_type,
)
with use_symmetric_memory(get_tp_group()) as sm:
# FIXME: there is a bug in trtllm_fp8_block_scale_moe: it ignores the `output` argument.
# See https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
# As a workaround, we put the whole call under the `use_symmetric_memory` context manager.
# Once the bug is fixed, only the output tensor allocation needs to stay under the context manager.
output = trtllm_fp8_per_tensor_scale_moe(
routing_logits=routing_logits_cast,
routing_bias=routing_bias_cast,
hidden_states=x_fp8,
gemm1_weights=layer.w13_weight,
output1_scales_scalar=layer.output1_scales_scalar,
output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
gemm2_weights=layer.w2_weight,
output2_scales_scalar=layer.output2_scales_scalar,
num_experts=layer.num_experts,
top_k=topk_config.top_k,
n_group=0,
topk_group=0,
intermediate_size=layer.w2_weight.shape[2],
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,
routed_scaling_factor=(
routed_scaling_factor
if routed_scaling_factor is not None
else 1.0
),
use_routing_scales_on_input=use_routing_scales_on_input,
tile_tokens_dim=8, # TODO(brayden): use the FI tile calculation
routing_method_type=routing_method_type,
)
sm.tag(output)
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
......@@ -1587,6 +1598,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
)
x_sf = nvfp4_block_scale_interleave(x_sf)
with use_symmetric_memory(get_tp_group()) as sm:
symm_output = torch.empty(
x.shape[0], x.shape[1], dtype=output_dtype, device=x.device
)
sm.tag(symm_output)
output = flashinfer_cutlass_fused_moe(
input=x,
token_selected_experts=topk_ids.to(torch.int),
......@@ -1608,6 +1625,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
tp_size=layer.moe_tp_size,
tp_rank=layer.moe_tp_rank,
tune_max_num_tokens=next_power_of_2(x.shape[0]),
output=symm_output,
)[0]
if should_use_flashinfer_cutlass_moe_fp4_allgather():
output, global_output = get_local_dp_buffer(), output
......
......@@ -22,6 +22,10 @@ from typing import TYPE_CHECKING, List, Optional
import torch
from torch.nn.parameter import Parameter
from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.utils import get_moe_runner_backend
......@@ -70,14 +74,14 @@ _is_hip = is_hip()
if _is_hip:
# import aiter
try:
from aiter import ActivationType, QuantType, dtypes
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from aiter.utility.fp4_utils import e8m0_shuffle
except ImportError as err:
ActivationType = QuantType = dtypes = fused_moe = dynamic_mxfp4_quant = (
e8m0_shuffle
) = err
ActivationType = QuantType = fused_moe = dynamic_mxfp4_quant = e8m0_shuffle = (
err
)
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
......@@ -606,8 +610,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
moe_runner_config = self.moe_runner_config
if self.use_flashinfer:
# When bf16 mode is enabled, we don't need to quantize the input,
# TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
......@@ -630,7 +632,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
x_quant, x_scale = mxfp8_quantize(x, False, alignment=self.hidden_size)
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
else:
raise NotImplementedError
raise NotImplementedError()
assert x_quant.shape[-1] == self.hidden_size
assert TopKOutputChecker.format_is_bypassed(topk_output)
......@@ -638,6 +640,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
top_k = topk_output.topk_config.top_k
router_logits = topk_output.router_logits
with use_symmetric_memory(get_tp_group()) as sm:
symm_output = torch.empty_like(x)
sm.tag(symm_output)
trtllm_gen_output = trtllm_fp4_block_scale_moe(
router_logits.to(torch.bfloat16),
None, # routing_bias
......@@ -666,6 +672,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
None, # tile_tokens_dim
1, # routing_method_type, renormalize
True, # do finalize
output=symm_output,
)[0]
return StandardCombineInput(hidden_states=trtllm_gen_output)
......
......@@ -11,7 +11,7 @@ from sglang.srt.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
......@@ -473,7 +473,7 @@ class VocabParallelEmbedding(torch.nn.Module):
else:
masked_input = input_
# Get the embeddings.
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
with use_symmetric_memory(get_tp_group()) as sm:
output_parallel = self.quant_method.embedding(self, masked_input.long())
sm.tag(output_parallel)
# Mask the output embedding.
......
......@@ -39,12 +39,8 @@ from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.environ import envs
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
......@@ -758,12 +754,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states *= self.routed_scaling_factor
current_stream.wait_stream(self.alt_stream)
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
final_hidden_states += shared_output
if (
self.tp_size > 1
and not should_allreduce_fusion
......@@ -822,11 +813,8 @@ class DeepseekV2MoE(nn.Module):
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
final_hidden_states += shared_output
if (
self.tp_size > 1
and not should_allreduce_fusion
......
......@@ -3174,7 +3174,7 @@ class ServerArgs:
parser.add_argument(
"--enable-torch-symm-mem",
action="store_true",
help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL. Only supports CUDA device SM90 and above. SM90 supports world size 4, 6, 8. SM10 supports world size 6, 8.",
help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL. Only supports CUDA device SM90 and above. SM90 supports world size 4, 6, 8. SM100 supports world size 6, 8.",
)
parser.add_argument(
"--disable-overlap-schedule",
......
......@@ -3605,3 +3605,13 @@ def calc_diff(x, y):
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
cached_device_index = -1
def get_current_device_stream_fast():
global cached_device_index
if cached_device_index == -1:
cached_device_index = torch.get_device_module().current_device()
return torch.get_device_module().current_stream(cached_device_index)
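# Why the cached index helps (a sketch, not part of the commit): calling
# torch.cuda.current_stream() with no argument has to resolve the current device
# on every call, whereas passing an explicit device index skips that lookup.
# Since each worker process stays pinned to one GPU, the index can be resolved
# once and reused on the hot path. A rough micro-benchmark sketch, assuming a
# CUDA build of PyTorch:
import time
import torch

def bench(fn, iters=100_000):
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return time.perf_counter() - t0

dev = torch.cuda.current_device()                      # resolved once
t_default = bench(lambda: torch.cuda.current_stream()) # re-queries the device each call
t_cached = bench(lambda: torch.cuda.current_stream(dev))
print(f"default: {t_default:.3f}s  cached index: {t_cached:.3f}s")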