"git@developer.sourcefind.cn:change/sglang.git" did not exist on "c44e985dc20ec79dcf4e64a9c1f6b8fa395d853b"
Unverified Commit 7a21d8b2 authored by Lianmin Zheng, committed by GitHub

Reduce the overhead of nccl symmetric memory (#12524)


Co-authored-by: Nicolas Castet <ncastet@nvidia.com>
parent d36639ee
@@ -19,6 +19,7 @@ from sglang.srt.distributed.device_communicators.pynccl_wrapper import (
    ncclUniqueId,
)
from sglang.srt.distributed.utils import StatelessProcessGroup
from sglang.srt.utils.common import get_current_device_stream_fast

logger = logging.getLogger(__name__)

@@ -137,7 +138,7 @@ class PyNcclCommunicator:
        if stream is not None:
            return stream
        if self.use_current_stream:
            return get_current_device_stream_fast()
        return self.stream

    def all_reduce(
...
import os
import tempfile

import torch
from torch.cuda.memory import CUDAPluggableAllocator

from sglang.srt.distributed.parallel_state import GroupCoordinator
@@ -9,13 +9,22 @@ from sglang.srt.server_args import get_global_server_args

nccl_allocator_source = """
#include <nccl.h>
extern "C" {

void* nccl_alloc_plug(size_t size, int device, void* stream) {
  void* ptr;
  ncclResult_t err = ncclMemAlloc(&ptr, size);
  const char *str_val = getenv("SGLANG_TMP_NCCL_COMM_VALUE");
  char *endptr;
  void* int_val = (void *)strtoull(str_val, &endptr, 0);
  ncclComm_t comm = (ncclComm_t)(int_val);
  ncclWindow_t win;
  ncclResult_t err2 = ncclCommWindowRegister(comm, ptr, size, &win, NCCL_WIN_COLL_SYMMETRIC);
  return ptr;
}

void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
@@ -27,8 +36,8 @@ void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {

_allocator = None
_mem_pool = None
_graph_pool_id = None
_cur_device = None


def is_symmetric_memory_enabled():
@@ -41,7 +50,7 @@ def set_graph_pool_id(graph_pool_id):


def get_nccl_mem_pool():
    global _allocator, _mem_pool, _cur_device
    if _mem_pool is None:
        out_dir = tempfile.gettempdir()
        nccl_allocator_libname = "nccl_allocator"
@@ -60,74 +69,67 @@ def get_nccl_mem_pool():
            "nccl_free_plug",
        ).allocator()
        _mem_pool = torch.cuda.MemPool(_allocator)
        _cur_device = torch.cuda.current_device()
    return _mem_pool
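The compile step collapsed in the hunk above builds `nccl_allocator_source` into a shared library and hands it to PyTorch as a pluggable allocator. For readers new to that machinery, here is a minimal sketch of the wiring, assuming a prebuilt library at a hypothetical path (the build itself is elided, as in the diff):

# Hedged sketch, not the sglang implementation: wiring a compiled allocator library
# into a torch MemPool. Assumes "libnccl_allocator.so" was already built from the
# C source above; the path is hypothetical.
import torch
from torch.cuda.memory import CUDAPluggableAllocator

allocator = CUDAPluggableAllocator(
    "/tmp/libnccl_allocator.so",  # hypothetical build artifact of nccl_allocator_source
    "nccl_alloc_plug",            # malloc hook: ncclMemAlloc + ncclCommWindowRegister
    "nccl_free_plug",             # free hook
).allocator()
mem_pool = torch.cuda.MemPool(allocator)

# Every tensor allocated while this pool is active is backed by NCCL symmetric memory
# and is already window-registered when the allocation returns.
with torch.cuda.use_mem_pool(mem_pool):
    buf = torch.empty(1 << 20, device="cuda", dtype=torch.uint8)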
class use_symmetric_memory:
    """
    Context manager for using symmetric memory with pynccl.

    To utilize the symmetric memory feature in NCCL, the buffers need to be allocated
    by `ncclMemAlloc` and registered by `ncclCommWindowRegister`. Due to this, we introduce
    this context manager. All tensors created under this context will be correctly
    allocated and registered with the custom allocator.

    In addition, developers need to manually tag the tensors that will be used as the
    input/output of NCCL collectives with `tag(tensor)`.
    """

    def __init__(self, group_coordinator: GroupCoordinator):
        self.enabled = is_symmetric_memory_enabled()

        if not self.enabled:
            return

        self.group_coordinator = group_coordinator
        self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
        self.is_graph_capture = torch.cuda.is_current_stream_capturing()

    def __enter__(self):
        if not self.enabled:
            return self

        assert (
            self.group_coordinator.pynccl_comm is not None
        ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
        assert (
            self.group_coordinator.pynccl_comm.nccl_version >= 22703
        ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
        if self.is_graph_capture:
            assert (
                _graph_pool_id is not None
            ), "graph_pool_id is not set under graph capture"
            # Pause the graph memory pool to use symmetric memory with CUDA graph
            torch._C._cuda_endAllocateToPool(_cur_device, _graph_pool_id)
        self._mem_pool_ctx.__enter__()

        # Set the env var to pass the communicator handle to the C allocator functions.
        os.environ["SGLANG_TMP_NCCL_COMM_VALUE"] = str(
            self.group_coordinator.pynccl_comm.comm.value
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return

        self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
        if self.is_graph_capture:
            torch._C._cuda_beginAllocateCurrentThreadToPool(_cur_device, _graph_pool_id)

    def tag(self, tensor: torch.Tensor):
        if not self.enabled:
            return

        tensor.symmetric_memory = True
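For reference, a minimal usage sketch of the context manager, following the call-site pattern introduced later in this commit (shapes and dtype are illustrative):

# Illustrative only; mirrors the call sites added elsewhere in this diff.
import torch

from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)

with use_symmetric_memory(get_tp_group()) as sm:
    # Allocated from the NCCL mem pool, so the buffer is ncclMemAlloc'ed and
    # window-registered by the allocator hook above.
    out = torch.empty(4096, 8192, device="cuda", dtype=torch.bfloat16)
    sm.tag(out)  # mark it for the symmetric-memory all-reduce path

# GroupCoordinator.all_reduce checks `out.symmetric_memory` and routes the call
# through pynccl on the current stream (see the parallel_state hunks below).
out = get_tp_group().all_reduce(out)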
@@ -43,6 +43,7 @@ from sglang.srt.environ import envs
from sglang.srt.utils import (
    direct_register_custom_op,
    get_bool_env_var,
    get_current_device_stream_fast,
    get_int_env_var,
    get_local_ip_auto,
    is_cpu,
@@ -466,7 +467,7 @@ class GroupCoordinator:
        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        curr_stream = get_current_device_stream_fast()
        if curr_stream != stream:
            stream.wait_stream(curr_stream)
@@ -500,7 +501,7 @@ class GroupCoordinator:
            maybe_pynccl_context = nullcontext()
        else:
            maybe_pynccl_context = pynccl_comm.change_state(
                enable=True, stream=get_current_device_stream_fast()
            )
        pymscclpp_comm = self.pymscclpp_comm
@@ -551,13 +552,9 @@ class GroupCoordinator:
        if self.npu_communicator is not None and not self.npu_communicator.disabled:
            return self.npu_communicator.all_reduce(input_)

        if self.pynccl_comm is not None and getattr(input_, "symmetric_memory", False):
            with self.pynccl_comm.change_state(
                enable=True, stream=get_current_device_stream_fast()
            ):
                self.pynccl_comm.all_reduce(input_)
                return input_
@@ -658,7 +655,7 @@ class GroupCoordinator:
        pynccl_comm = self.pynccl_comm
        with pynccl_comm.change_state(
            enable=True, stream=get_current_device_stream_fast()
        ):
            assert (
                pynccl_comm is not None and not pynccl_comm.disabled
@@ -784,7 +781,7 @@ class GroupCoordinator:
        pynccl_comm = self.pynccl_comm
        with pynccl_comm.change_state(
            enable=True, stream=get_current_device_stream_fast()
        ):
            assert (
                pynccl_comm is not None and not pynccl_comm.disabled
...
@@ -677,10 +677,16 @@ class Engine(EngineBase):
def _set_envs_and_config(server_args: ServerArgs):
    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    if "NCCL_CUMEM_ENABLE" not in os.environ or server_args.enable_symm_mem:
        os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
    if (
        "NCCL_NVLS_ENABLE" not in os.environ
        or server_args.enable_nccl_nvls
        or server_args.enable_symm_mem
    ):
        os.environ["NCCL_NVLS_ENABLE"] = str(
            int(server_args.enable_nccl_nvls or server_args.enable_symm_mem)
        )
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
    os.environ["CUDA_MODULE_LOADING"] = "AUTO"
...
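In short, enabling symmetric memory now forces both NCCL_CUMEM_ENABLE and NCCL_NVLS_ENABLE on, even if they were preset in the environment. A small self-contained sketch that mirrors the logic above, for reasoning about the result (the helper name is hypothetical):

import os


def sketch_nccl_env(enable_symm_mem: bool, enable_nccl_nvls: bool) -> dict:
    # Mirror of _set_envs_and_config's NCCL handling, applied to a copy of the environment.
    env = dict(os.environ)
    if "NCCL_CUMEM_ENABLE" not in env or enable_symm_mem:
        env["NCCL_CUMEM_ENABLE"] = str(int(enable_symm_mem))
    if "NCCL_NVLS_ENABLE" not in env or enable_nccl_nvls or enable_symm_mem:
        env["NCCL_NVLS_ENABLE"] = str(int(enable_nccl_nvls or enable_symm_mem))
    return env


# Starting from a clean environment:
#   sketch_nccl_env(True, False)  -> CUMEM "1", NVLS "1"  (both forced on for symmetric memory)
#   sketch_nccl_env(False, True)  -> CUMEM "0", NVLS "1"
#   sketch_nccl_env(False, False) -> CUMEM "0", NVLS "0"  (preset values are respected)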
@@ -13,7 +13,7 @@ from sglang.srt.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tp_group,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
@@ -1372,7 +1372,7 @@ class RowParallelLinear(LinearBase):
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        with use_symmetric_memory(get_tp_group()) as sm:
            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
            sm.tag(output_parallel)
...
"""CUTLASS based Fused MoE kernels.""" """CUTLASS based Fused MoE kernels."""
from typing import Optional
import torch import torch
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
...@@ -40,6 +42,7 @@ def cutlass_fused_experts_fp8( ...@@ -40,6 +42,7 @@ def cutlass_fused_experts_fp8(
problem_sizes1: torch.Tensor, problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor, problem_sizes2: torch.Tensor,
use_fp8_blockscale: bool = True, use_fp8_blockscale: bool = True,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations. """Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.
...@@ -200,9 +203,11 @@ def cutlass_fused_experts_fp8( ...@@ -200,9 +203,11 @@ def cutlass_fused_experts_fp8(
workspace, workspace,
) )
result = torch.empty((m, k), device=device, dtype=out_dtype) if output is None:
apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype)) output = torch.empty((m, k), device=device, dtype=out_dtype)
return result
apply_shuffle_mul_sum(c2, output, c_map, topk_weights.to(out_dtype))
return output
FLOAT4_E2M1_MAX = 6.0 FLOAT4_E2M1_MAX = 6.0
......
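The new optional `output` parameter lets a caller hand in a pre-allocated destination (for example, a tagged symmetric-memory buffer) instead of having the kernel allocate its own result tensor. A self-contained stand-in sketch of that convention (`fused_op` is hypothetical, not an sglang API):

from typing import Optional

import torch


def fused_op(x: torch.Tensor, output: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Same convention as cutlass_fused_experts_fp8 above: allocate only when the
    # caller did not supply a destination buffer.
    if output is None:
        output = torch.empty_like(x)
    torch.mul(x, 2.0, out=output)  # stand-in for the real MoE epilogue
    return output


x = torch.randn(8, 16, device="cuda")
symm = torch.empty_like(x)  # imagine this was allocated and tagged under use_symmetric_memory
assert fused_op(x, output=symm) is symm  # the kernel writes in place; no extra copy before all-reduce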
@@ -14,6 +14,9 @@ from sglang.srt.distributed import (
    get_tp_group,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe import (
    MoeRunnerConfig,
@@ -55,11 +58,6 @@ from sglang.srt.utils import (

if is_flashinfer_available():
    from flashinfer import RoutingMethodType, fp4_quantize

# Try to import FP4 TRTLLM function if flashinfer is available
trtllm_fp4_block_scale_moe = None
if should_use_flashinfer_trtllm_moe():
@@ -68,6 +66,10 @@ if should_use_flashinfer_trtllm_moe():
    except ImportError:
        trtllm_fp4_block_scale_moe = None

_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

logger = logging.getLogger(__name__)
@@ -839,12 +841,16 @@ class FusedMoE(torch.nn.Module):
            dispatch_output=dispatch_output,
            **kwargs,
        )

        with use_symmetric_memory(get_tp_group()) as sm:
            final_hidden_states = self.dispatcher.combine(combine_input=combine_input)

            # TODO: should we add some conditions here?
            final_hidden_states = final_hidden_states[
                ..., :origin_hidden_states_dim
            ].contiguous()

            sm.tag(final_hidden_states)

        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
@@ -980,6 +986,11 @@ class FlashInferFusedMoE(FusedMoE):
            ),
        )

        # NOTE on symmetric memory tagging: we do not create the context in this
        # function. Instead, the context and the tagging live inside each
        # FusedMoEMethodBase, which allows fine-grained tagging.
        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
@@ -1040,6 +1051,10 @@ class FlashInferFP4MoE(FusedMoE):
        router_logits = router_logits.to(torch.float32)

        with use_symmetric_memory(get_tp_group()) as sm:
            symm_output = torch.empty_like(hidden_states)
            sm.tag(symm_output)

        result = trtllm_fp4_block_scale_moe(
            routing_logits=router_logits,
            routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
@@ -1072,6 +1087,7 @@ class FlashInferFP4MoE(FusedMoE):
            tile_tokens_dim=None,
            routing_method_type=RoutingMethodType.DeepSeekV3,
            do_finalize=True,
            output=symm_output,
        )[0]

        return result
@@ -28,7 +28,10 @@ except ImportError:
    apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func

from sglang.srt.distributed import get_tensor_model_parallel_world_size, get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
@@ -1025,6 +1028,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        if self._should_use_cutlass_fused_experts():
            from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

            with use_symmetric_memory(get_tp_group()) as sm:
                symm_output = torch.empty_like(x)
                sm.tag(symm_output)

            topk_weights, topk_ids, _ = dispatch_output.topk_output
            output = cutlass_fused_experts_fp8(
                x,
@@ -1048,6 +1055,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                self.problem_sizes1,
                self.problem_sizes2,
                use_fp8_blockscale=True,
                output=symm_output,
            )

            return StandardCombineInput(hidden_states=output)
@@ -1211,31 +1219,38 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            else topk_config.correction_bias.to(x.dtype)
        )

        with use_symmetric_memory(get_tp_group()) as sm:
            # FIXME: there is a bug in trtllm_fp8_block_scale_moe: it ignores the `output` argument.
            # https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
            # So we put the whole call under the `use_symmetric_memory` context manager.
            # Once the bug is fixed, only the output tensor allocation needs to be under the context manager.
            output = trtllm_fp8_block_scale_moe(
                routing_logits=router_logits.to(torch.float32),
                routing_bias=correction_bias,
                hidden_states=a_q,
                hidden_states_scale=a_sf_t,
                gemm1_weights=layer.w13_weight,
                gemm1_weights_scale=layer.w13_weight_scale_inv,
                gemm2_weights=layer.w2_weight,
                gemm2_weights_scale=layer.w2_weight_scale_inv,
                num_experts=layer.num_experts,
                top_k=topk_config.top_k,
                n_group=topk_config.num_expert_group,
                topk_group=topk_config.topk_group,
                intermediate_size=layer.w2_weight.shape[2],
                local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
                local_num_experts=layer.num_local_experts,
                routed_scaling_factor=(
                    routed_scaling_factor if routed_scaling_factor is not None else 1.0
                ),
                tile_tokens_dim=get_tile_tokens_dim(
                    x.shape[0], topk_config.top_k, layer.num_experts
                ),
                routing_method_type=2,  # DeepSeek-styled routing method
                use_shuffled_weight=False,
            )
            sm.tag(output)

        return output

    def maybe_apply_hip_fused_experts(
        self,
...
@@ -8,6 +8,9 @@ import torch
from torch.nn.parameter import Parameter

from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
from sglang.srt.layers.moe import (
    MoeRunner,
@@ -659,29 +662,37 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
            None if correction_bias is None else correction_bias.to(torch.bfloat16)
        )

        with use_symmetric_memory(get_tp_group()) as sm:
            # FIXME: there is a bug in trtllm_fp8_block_scale_moe: it ignores the `output` argument.
            # https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
            # So we put the whole call under the `use_symmetric_memory` context manager.
            # Once the bug is fixed, only the output tensor allocation needs to be under the context manager.
            output = trtllm_fp8_per_tensor_scale_moe(
                routing_logits=routing_logits_cast,
                routing_bias=routing_bias_cast,
                hidden_states=x_fp8,
                gemm1_weights=layer.w13_weight,
                output1_scales_scalar=layer.output1_scales_scalar,
                output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
                gemm2_weights=layer.w2_weight,
                output2_scales_scalar=layer.output2_scales_scalar,
                num_experts=layer.num_experts,
                top_k=topk_config.top_k,
                n_group=0,
                topk_group=0,
                intermediate_size=layer.w2_weight.shape[2],
                local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
                local_num_experts=layer.num_local_experts,
                routed_scaling_factor=(
                    routed_scaling_factor
                    if routed_scaling_factor is not None
                    else 1.0
                ),
                use_routing_scales_on_input=use_routing_scales_on_input,
                tile_tokens_dim=8,  # TODO(brayden): use the FI tile calculation
                routing_method_type=routing_method_type,
            )
            sm.tag(output)

        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
@@ -1587,6 +1598,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
        )
        x_sf = nvfp4_block_scale_interleave(x_sf)

        with use_symmetric_memory(get_tp_group()) as sm:
            symm_output = torch.empty(
                x.shape[0], x.shape[1], dtype=output_dtype, device=x.device
            )
            sm.tag(symm_output)

        output = flashinfer_cutlass_fused_moe(
            input=x,
            token_selected_experts=topk_ids.to(torch.int),
@@ -1608,6 +1625,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
            tp_size=layer.moe_tp_size,
            tp_rank=layer.moe_tp_rank,
            tune_max_num_tokens=next_power_of_2(x.shape[0]),
            output=symm_output,
        )[0]

        if should_use_flashinfer_cutlass_moe_fp4_allgather():
            output, global_output = get_local_dp_buffer(), output
...
@@ -22,6 +22,10 @@ from typing import TYPE_CHECKING, List, Optional
import torch
from torch.nn.parameter import Parameter

from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.utils import get_moe_runner_backend
@@ -70,14 +74,14 @@ _is_hip = is_hip()
if _is_hip:
    # import aiter
    try:
        from aiter import ActivationType, QuantType
        from aiter.fused_moe import fused_moe
        from aiter.ops.triton.quant import dynamic_mxfp4_quant
        from aiter.utility.fp4_utils import e8m0_shuffle
    except ImportError as err:
        ActivationType = QuantType = fused_moe = dynamic_mxfp4_quant = e8m0_shuffle = (
            err
        )


def _swizzle_mxfp4(quant_tensor, scale, num_warps):
@@ -606,8 +610,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output

        if self.use_flashinfer:
            # When bf16 mode is enabled, we don't need to quantize the input,
            # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
@@ -630,7 +632,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                x_quant, x_scale = mxfp8_quantize(x, False, alignment=self.hidden_size)
                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
            else:
                raise NotImplementedError()

            assert x_quant.shape[-1] == self.hidden_size
            assert TopKOutputChecker.format_is_bypassed(topk_output)
@@ -638,6 +640,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
            top_k = topk_output.topk_config.top_k
            router_logits = topk_output.router_logits

            with use_symmetric_memory(get_tp_group()) as sm:
                symm_output = torch.empty_like(x)
                sm.tag(symm_output)

            trtllm_gen_output = trtllm_fp4_block_scale_moe(
                router_logits.to(torch.bfloat16),
                None,  # routing_bias
@@ -666,6 +672,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                None,  # tile_tokens_dim
                1,  # routing_method_type, renormalize
                True,  # do finalize
                output=symm_output,
            )[0]

            return StandardCombineInput(hidden_states=trtllm_gen_output)
...
@@ -11,7 +11,7 @@ from sglang.srt.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tp_group,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
@@ -473,7 +473,7 @@ class VocabParallelEmbedding(torch.nn.Module):
        else:
            masked_input = input_
        # Get the embeddings.
        with use_symmetric_memory(get_tp_group()) as sm:
            output_parallel = self.quant_method.embedding(self, masked_input.long())
            sm.tag(output_parallel)
        # Mask the output embedding.
...
@@ -39,12 +39,8 @@ from sglang.srt.distributed import (
    get_moe_expert_parallel_world_size,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.environ import envs
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
@@ -758,12 +754,7 @@ class DeepseekV2MoE(nn.Module):
            final_hidden_states *= self.routed_scaling_factor
            current_stream.wait_stream(self.alt_stream)

        final_hidden_states += shared_output

        if (
            self.tp_size > 1
            and not should_allreduce_fusion
@@ -822,11 +813,8 @@ class DeepseekV2MoE(nn.Module):
            # fused in biased_grouped_topk so we can skip here
            final_hidden_states *= self.routed_scaling_factor
        if shared_output is not None:
            final_hidden_states += shared_output

        if (
            self.tp_size > 1
            and not should_allreduce_fusion
...
@@ -3174,7 +3174,7 @@ class ServerArgs:
        parser.add_argument(
            "--enable-torch-symm-mem",
            action="store_true",
            help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL. Only supports CUDA device SM90 and above. SM90 supports world size 4, 6, 8. SM100 supports world size 6, 8.",
        )
        parser.add_argument(
            "--disable-overlap-schedule",
...
@@ -3605,3 +3605,13 @@ def calc_diff(x, y):
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


cached_device_index = -1


def get_current_device_stream_fast():
    # Cache the device index on first use so later calls skip the relatively
    # expensive current_device() query and only perform the stream lookup.
    global cached_device_index
    if cached_device_index == -1:
        cached_device_index = torch.get_device_module().current_device()
    return torch.get_device_module().current_stream(cached_device_index)
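This helper is what the communicator and GroupCoordinator changes above switch to: the device index is cached once per process (fine for the usual one-device-per-process setup), so only the stream lookup remains on the hot path. A rough micro-benchmark sketch for checking the difference locally (hypothetical harness; actual numbers depend on the machine and PyTorch build):

import time

import torch

from sglang.srt.utils.common import get_current_device_stream_fast


def bench(fn, iters=10000):
    # Time the pure Python/driver overhead of the stream query.
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters * 1e6  # microseconds per call


torch.cuda.init()
print("torch.cuda.current_stream()      :", bench(torch.cuda.current_stream), "us/call")
print("get_current_device_stream_fast() :", bench(get_current_device_stream_fast), "us/call")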