Unverified Commit eff4eb3f authored by Trevor Morris, committed by GitHub

Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667)

parent 87dab548
@@ -148,7 +148,11 @@ class PyNcclCommunicator:
         )
 
     def all_gather(
-        self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        stream=None,
+        sizes: Optional[list[int]] = None,
     ):
         if self.disabled:
             return
@@ -161,14 +165,33 @@ class PyNcclCommunicator:
             )
         if stream is None:
             stream = self.stream
-        self.nccl.ncclAllGather(
-            buffer_type(input_tensor.data_ptr()),
-            buffer_type(output_tensor.data_ptr()),
-            input_tensor.numel(),
-            ncclDataTypeEnum.from_torch(input_tensor.dtype),
-            self.comm,
-            cudaStream_t(stream.cuda_stream),
-        )
+
+        if sizes is not None:
+            split_offset = 0
+
+            self.nccl.ncclGroupStart()
+            for root, split_size in enumerate(sizes):
+                dst_slice = output_tensor[split_offset : split_offset + split_size]
+                self.nccl.ncclBroadcast(
+                    buffer_type(input_tensor.data_ptr()),
+                    buffer_type(dst_slice.data_ptr()),
+                    dst_slice.numel(),
+                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                    root,
+                    self.comm,
+                    cudaStream_t(stream.cuda_stream),
+                )
+                split_offset += split_size
+            self.nccl.ncclGroupEnd()
+        else:
+            self.nccl.ncclAllGather(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(output_tensor.data_ptr()),
+                input_tensor.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
 
     def reduce_scatter(
         self,
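ncclAllGather requires every rank to contribute the same number of elements, so the sizes-aware branch above emulates a ragged (per-rank-sized) all-gather with one ncclBroadcast per rank, wrapped in ncclGroupStart/ncclGroupEnd so the broadcasts go out as a single grouped launch. A minimal single-process sketch of the resulting semantics (illustrative only, not code from this commit):

import torch

def ragged_all_gather_reference(
    per_rank_inputs: list[torch.Tensor], sizes: list[int]
) -> torch.Tensor:
    # per_rank_inputs[r] plays the role of rank r's input_tensor; its first
    # dimension must equal sizes[r].
    assert all(t.shape[0] == s for t, s in zip(per_rank_inputs, sizes))
    out = torch.empty(
        (sum(sizes),) + tuple(per_rank_inputs[0].shape[1:]),
        dtype=per_rank_inputs[0].dtype,
    )
    offset = 0
    for root, split_size in enumerate(sizes):
        # ncclBroadcast(root=root) writes rank root's chunk into this slice of the
        # output on every rank; here we simply copy it.
        out[offset : offset + split_size] = per_rank_inputs[root]
        offset += split_size
    return out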
@@ -176,6 +199,7 @@ class PyNcclCommunicator:
         input_tensor: torch.Tensor,
         op: ReduceOp = ReduceOp.SUM,
         stream=None,
+        sizes: Optional[list[int]] = None,
     ):
         if self.disabled:
             return
@@ -188,15 +212,35 @@ class PyNcclCommunicator:
             )
         if stream is None:
             stream = self.stream
-        self.nccl.ncclReduceScatter(
-            buffer_type(input_tensor.data_ptr()),
-            buffer_type(output_tensor.data_ptr()),
-            output_tensor.numel(),
-            ncclDataTypeEnum.from_torch(input_tensor.dtype),
-            ncclRedOpTypeEnum.from_torch(op),
-            self.comm,
-            cudaStream_t(stream.cuda_stream),
-        )
+
+        if sizes is not None:
+            split_offset = 0
+            self.nccl.ncclGroupStart()
+            for root, split_size in enumerate(sizes):
+                chunk = input_tensor[split_offset : split_offset + split_size, ...]
+
+                self.nccl.ncclReduce(
+                    buffer_type(chunk.data_ptr()),
+                    buffer_type(output_tensor.data_ptr()),
+                    chunk.numel(),
+                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                    ncclRedOpTypeEnum.from_torch(op),
+                    root,
+                    self.comm,
+                    cudaStream_t(stream.cuda_stream),
+                )
+                split_offset += split_size
+            self.nccl.ncclGroupEnd()
+        else:
+            self.nccl.ncclReduceScatter(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(output_tensor.data_ptr()),
+                output_tensor.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                ncclRedOpTypeEnum.from_torch(op),
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
 
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
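ncclReduceScatter likewise assumes equal-sized chunks, so the sizes-aware branch above maps a ragged reduce-scatter onto one ncclReduce per destination rank: rank r receives the element-wise reduction of every rank's r-th slice. A single-process sketch of that semantics (illustrative only, not code from this commit):

import torch

def ragged_reduce_scatter_reference(
    per_rank_inputs: list[torch.Tensor], sizes: list[int], rank: int
) -> torch.Tensor:
    # per_rank_inputs[r] is rank r's full input of shape (sum(sizes), ...);
    # the return value is what `rank` would see in its output_tensor (ReduceOp.SUM).
    offset = sum(sizes[:rank])
    chunk = sizes[rank]
    slices = [inp[offset : offset + chunk] for inp in per_rank_inputs]
    return torch.stack(slices).sum(dim=0)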
@@ -266,6 +310,12 @@ class PyNcclCommunicator:
     def deregister_comm_window(self, window):
         return self.nccl.ncclCommWindowDeregister(self.comm, window)
 
+    def group_start(self):
+        self.nccl.ncclGroupStart()
+
+    def group_end(self):
+        self.nccl.ncclGroupEnd()
+
     @contextmanager
     def change_state(
         self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
......
@@ -206,6 +206,26 @@ class NCCLLibrary:
                 cudaStream_t,
             ],
         ),
+        # ncclResult_t ncclReduce(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, int root,
+        #   ncclComm_t comm, cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function(
+            "ncclReduce",
+            ncclResult_t,
+            [
+                buffer_type,
+                buffer_type,
+                ctypes.c_size_t,
+                ncclDataType_t,
+                ncclRedOp_t,
+                ctypes.c_int,
+                ncclComm_t,
+                cudaStream_t,
+            ],
+        ),
         # ncclResult_t ncclReduceScatter(
         #   const void* sendbuff, void* recvbuff, size_t count,
         #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
@@ -278,6 +298,10 @@
         # it is better not to call it at all.
         # ncclResult_t ncclCommDestroy(ncclComm_t comm);
         Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+        # ncclResult_t ncclGroupStart();
+        Function("ncclGroupStart", ncclResult_t, []),
+        # ncclResult_t ncclGroupEnd();
+        Function("ncclGroupEnd", ncclResult_t, []),
     ]
 
     exported_functions_symm_mem = [
@@ -400,6 +424,28 @@ class NCCLLibrary:
             )
         )
 
+    def ncclReduce(
+        self,
+        sendbuff: buffer_type,
+        recvbuff: buffer_type,
+        count: int,
+        datatype: int,
+        op: int,
+        root: int,
+        comm: ncclComm_t,
+        stream: cudaStream_t,
+    ) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(
+            self._funcs["ncclReduce"](
+                sendbuff, recvbuff, count, datatype, op, root, comm, stream
+            )
+        )
+
     def ncclReduceScatter(
         self,
         sendbuff: buffer_type,
@@ -499,6 +545,12 @@ class NCCLLibrary:
     def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
 
+    def ncclGroupStart(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
+
+    def ncclGroupEnd(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
+
 
 __all__ = [
     "NCCLLibrary",
......
@@ -583,6 +583,39 @@ class GroupCoordinator:
         torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
         return output
 
+    def reduce_scatterv(
+        self,
+        input_: torch.Tensor,
+        output: Optional[torch.Tensor] = None,
+        sizes: Optional[List[int]] = None,
+    ) -> torch.Tensor:
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+            assert (
+                pynccl_comm is not None and not pynccl_comm.disabled
+            ), "pynccl is required for reduce_scatterv"
+
+            if sizes is not None:
+                assert len(sizes) == world_size
+                assert input_.shape[0] == sum(sizes)
+                chunk_size = sizes[self.rank_in_group]
+            else:
+                assert input_.shape[0] % world_size == 0
+                chunk_size = input_.shape[0] // world_size
+            output_shape = (chunk_size,) + input_.shape[1:]
+
+            if output is None:
+                output = torch.empty(
+                    output_shape, dtype=input_.dtype, device=input_.device
+                )
+            else:
+                assert output.shape == output_shape
+
+            pynccl_comm.reduce_scatter(output, input_, sizes=sizes)
+        return output
+
     def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
         pynccl_comm = self.pynccl_comm
         if pynccl_comm is not None and not pynccl_comm.disabled:
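reduce_scatterv is the combine-side counterpart of all_gatherv (added in the next hunk): after the MoE runs on the globally gathered tokens, each DP rank takes back only its own rows, summed across ranks. A hedged usage sketch, assuming distributed groups are already initialized and that sizes matches the per-rank token counts used for the earlier gather (the helper name is illustrative):

import torch
from sglang.srt.distributed import get_tp_group

def combine_to_local(global_output: torch.Tensor, sizes: list[int]) -> torch.Tensor:
    # global_output holds sum(sizes) rows; this rank gets back its sizes[rank] rows,
    # reduced (summed) across the group.
    return get_tp_group().reduce_scatterv(global_output, sizes=sizes)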
@@ -673,6 +706,54 @@
         )
         return output_tensor
 
+    def all_gatherv(
+        self,
+        input_: Union[torch.Tensor, List[torch.Tensor]],
+        sizes: Optional[List[int]] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """
+        Supports varying sizes per rank and input tensor list.
+        `sizes`: a list of len(world_size) with the number of items per rank to gather.
+        """
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+            assert (
+                pynccl_comm is not None and not pynccl_comm.disabled
+            ), "pynccl is required for all_gatherv"
+
+            def _all_gather_single(
+                input_: torch.Tensor, sizes: Optional[List[int]] = None
+            ):
+                input_size = input_.size()
+                if sizes is not None:
+                    assert len(sizes) == world_size
+                    assert input_.shape[0] == sizes[self.rank_in_group]
+                    output_size = (sum(sizes),) + input_size[1:]
+                    # 'sizes' is not needed if all inputs in the same group have the same shape
+                    if all(s == sizes[0] for s in sizes):
+                        sizes = None
+                else:
+                    output_size = (input_size[0] * world_size,) + input_size[1:]
+                # Allocate output tensor.
+                output_tensor = torch.empty(
+                    output_size, dtype=input_.dtype, device=input_.device
+                )
+                pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
+                return output_tensor
+
+            if isinstance(input_, torch.Tensor):
+                return _all_gather_single(input_, sizes)
+
+            output_list = []
+            pynccl_comm.group_start()
+            for inp in input_:
+                output_list.append(_all_gather_single(inp, sizes=sizes))
+            pynccl_comm.group_end()
+            return output_list
+
     def gather(
         self, input_: torch.Tensor, dst: int = 0, dim: int = -1
     ) -> Optional[torch.Tensor]:
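A hedged usage sketch of all_gatherv, mirroring the dispatch-side call added later in this commit in ModelOptNvFp4FusedMoEMethod (assumes an initialized TP group; the wrapper function is illustrative):

import torch
from sglang.srt.distributed import get_tp_group
from sglang.srt.layers.dp_attention import get_dp_global_num_tokens

def gather_dispatch_inputs(x, x_sf, topk_weights, topk_ids):
    sizes = get_dp_global_num_tokens()  # tokens contributed by each DP rank
    # Passing a list runs all four ragged all-gathers inside one pynccl group,
    # so they are issued as a single batch of collectives.
    topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
        [topk_weights, topk_ids, x, x_sf], sizes=sizes
    )
    return x, x_sf, topk_weights, topk_ids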
......
@@ -35,7 +35,10 @@ from sglang.srt.layers.dp_attention import (
     get_global_dp_buffer,
     get_local_dp_buffer,
 )
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -112,7 +115,11 @@ class LayerScatterModes:
         if context.is_layer_sparse:
             return (
                 ScatterMode.SCATTERED
-                if not get_moe_a2a_backend().is_none()
+                if (
+                    # Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
+                    not get_moe_a2a_backend().is_none()
+                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
+                )
                 else ScatterMode.FULL
             )
         else:
......
@@ -72,6 +72,7 @@ class _DpGatheredBufferWrapper:
     _device: torch.device
     _global_dp_buffer_len: int
     _local_dp_buffer_len: int
+    _global_num_tokens: Optional[List[int]]
 
     @classmethod
     def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
@@ -80,9 +81,15 @@
         cls._device = device
 
     @classmethod
-    def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
+    def set_dp_buffer_len(
+        cls,
+        global_dp_buffer_len: int,
+        local_dp_buffer_len: int,
+        global_num_tokens: Optional[List[int]] = None,
+    ):
         cls._global_dp_buffer_len = global_dp_buffer_len
         cls._local_dp_buffer_len = local_dp_buffer_len
+        cls._global_num_tokens = global_num_tokens
 
     @classmethod
     def get_global_dp_buffer(cls) -> torch.Tensor:
@@ -108,10 +115,18 @@
     def get_local_dp_buffer_len(cls) -> int:
         return cls._local_dp_buffer_len
 
+    @classmethod
+    def get_dp_global_num_tokens(cls) -> List[int]:
+        return cls._global_num_tokens
+
 
-def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
+def set_dp_buffer_len(
+    global_dp_buffer_len: int,
+    local_dp_buffer_len: int,
+    global_num_tokens: Optional[List[int]] = None,
+):
     _DpGatheredBufferWrapper.set_dp_buffer_len(
-        global_dp_buffer_len, local_dp_buffer_len
+        global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
     )
@@ -131,6 +146,10 @@ def get_local_dp_buffer_len() -> int:
     return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
 
 
+def get_dp_global_num_tokens() -> List[int]:
+    return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
......
@@ -191,7 +191,11 @@ class LogitsMetadata:
         else:
             self.global_dp_buffer_len = self.global_dp_buffer_len
 
-        set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)
+        set_dp_buffer_len(
+            self.global_dp_buffer_len,
+            self.dp_local_num_tokens,
+            self.global_num_tokens_for_logprob_cpu,
+        )
 
 
 class LogitsProcessor(nn.Module):
......
@@ -10,6 +10,7 @@ from sglang.srt.layers.moe.utils import (
     get_tbo_token_distribution_threshold,
     initialize_moe_config,
     is_tbo_enabled,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
@@ -23,6 +24,7 @@ __all__ = [
     "get_moe_runner_backend",
     "get_deepep_mode",
     "should_use_flashinfer_trtllm_moe",
+    "should_use_flashinfer_cutlass_moe_fp4_allgather",
     "is_tbo_enabled",
     "get_tbo_token_distribution_threshold",
     "get_deepep_config",
......
@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
@@ -621,9 +622,7 @@ class FusedMoE(torch.nn.Module):
         if "ModelOpt" in self.quant_method.__class__.__name__:
             # Determine per-tensor weight scale patterns based on variant
-            is_fp4_variant = (
-                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
-            )
+            is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
 
             # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
             per_tensor_conditions = (
......
@@ -327,6 +327,13 @@ class TopK(CustomOp):
             expert_location_dispatch_info=expert_location_dispatch_info,
         )
 
+    def empty_topk_output(self, device: torch.device) -> TopKOutput:
+        topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
+        topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
+        topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
+        router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
+        return StandardTopKOutput(topk_weights, topk_idx, router_logits)
+
 
 # ------------------------------- TopK implementation -------------------------------------
......
@@ -7,6 +7,11 @@ from typing import TYPE_CHECKING, Optional
 
 from packaging import version as pkg_version
 
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.utils import logger
 
 if TYPE_CHECKING:
@@ -99,6 +104,7 @@ DEEPEP_MODE: Optional[DeepEPMode] = None
 IS_TBO_ENABLED: Optional[bool] = None
 TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
 DEEPEP_CONFIG: Optional[str] = None
+DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
 
 
 def initialize_moe_config(server_args: ServerArgs):
@@ -108,6 +114,7 @@ def initialize_moe_config(server_args: ServerArgs):
     global DEEPEP_CONFIG
     global IS_TBO_ENABLED
     global TBO_TOKEN_DISTRIBUTION_THRESHOLD
+    global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
 
     MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
     MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
@@ -115,6 +122,9 @@ def initialize_moe_config(server_args: ServerArgs):
     DEEPEP_CONFIG = server_args.deepep_config or ""
     IS_TBO_ENABLED = server_args.enable_two_batch_overlap
     TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
+    DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
+        server_args.disable_flashinfer_cutlass_moe_fp4_allgather
+    )
 
 
 def get_moe_a2a_backend() -> MoeA2ABackend:
@@ -175,3 +185,16 @@ def should_use_flashinfer_trtllm_moe():
         >= pkg_version.parse("0.2.9rc1")
     )
     return result
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_cutlass_moe_fp4_allgather():
+    """
+    Perform FP4 quantize before all-gather for flashinfer cutlass moe to reduce communication cost for high-throughput serving.
+    """
+    return (
+        not DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
+        and get_moe_runner_backend().is_flashinfer_cutlass()
+        and is_dp_attention_enabled()
+        and get_moe_expert_parallel_world_size() == get_attention_dp_size()
+    )
@@ -7,7 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
+from sglang.srt.distributed import get_tp_group
+from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
+from sglang.srt.layers.moe import (
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -1176,16 +1181,37 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
             topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+            output_dtype = x.dtype
+            x_sf = None
+            if should_use_flashinfer_cutlass_moe_fp4_allgather():
+                from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
+
+                # Quantize before comm, swizzle after.
+                if x.shape[0] > 0:
+                    x, x_sf = fp4_quantize(
+                        x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
+                    )
+                else:
+                    x_col = x.shape[1]
+                    x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device)
+                    x_sf = torch.zeros(
+                        0, x_col // 16, dtype=torch.uint8, device=x.device
+                    )
+                topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
+                    [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens()
+                )
+                x_sf = nvfp4_block_scale_interleave(x_sf)
 
             output = flashinfer_cutlass_fused_moe(
-                x,
-                topk_ids.to(torch.int),
-                topk_weights,
-                layer.w13_weight.view(torch.long),
-                layer.w2_weight.view(torch.long),
-                x.dtype,
+                input=x,
+                token_selected_experts=topk_ids.to(torch.int),
+                token_final_scales=topk_weights,
+                fc1_expert_weights=layer.w13_weight.view(torch.long),
+                fc2_expert_weights=layer.w2_weight.view(torch.long),
+                output_dtype=output_dtype,
+                input_sf=x_sf,
                 quant_scales=[
                     layer.w13_input_scale_quant,
                     layer.w13_blockscale_swizzled.view(torch.int32),
@@ -1202,6 +1228,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             )[0]
             if moe_runner_config.routed_scaling_factor is not None:
                 output *= moe_runner_config.routed_scaling_factor
+            if should_use_flashinfer_cutlass_moe_fp4_allgather():
+                output, global_output = get_local_dp_buffer(), output
+                get_tp_group().reduce_scatterv(
+                    global_output, output=output, sizes=get_dp_global_num_tokens()
+                )
             return output
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
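The packed shapes created in the quantization hunk above (hidden // 2 uint8 FP4 codes plus hidden // 16 uint8 block scales per token) are what make quantize-before-all-gather pay off: far fewer bytes cross the interconnect per token than when gathering bf16 activations. A back-of-the-envelope check, assuming a hidden size of 7168 (DeepSeek-V3-class models; the exact figure varies by model):

hidden = 7168                              # assumed hidden size per token
bf16_bytes = 2 * hidden                    # 14336 bytes/token if gathered in bf16
nvfp4_bytes = hidden // 2 + hidden // 16   # 3584 packed codes + 448 block scales = 4032
print(bf16_bytes / nvfp4_bytes)            # ~3.6x less traffic per token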
......
@@ -84,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "chunked_prefill_size",
     "device",
     "disable_chunked_prefix_cache",
+    "disable_flashinfer_cutlass_moe_fp4_allgather",
     "disable_radix_cache",
     "enable_dp_lm_head",
     "enable_flashinfer_allreduce_fusion",
......
@@ -649,7 +649,7 @@ class ForwardBatch:
             num_tokens = global_num_tokens[0]
 
         self.global_dp_buffer_len = buffer_len
-        set_dp_buffer_len(buffer_len, num_tokens)
+        set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
 
         bs = self.batch_size
......
@@ -60,7 +60,11 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -343,7 +347,7 @@ class DeepseekV2MoE(nn.Module):
         self.shared_experts_weight_block_size = None
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            # disable tp for shared experts when enable deepep moe
+            # disable tp for shared experts when enable deepep moe, or with fp4 allgather
             self.shared_experts = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
@@ -354,6 +358,7 @@ class DeepseekV2MoE(nn.Module):
                 **(
                     dict(tp_rank=0, tp_size=1)
                     if get_moe_a2a_backend().is_deepep()
+                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
                     else {}
                 ),
             )
@@ -433,14 +438,19 @@ class DeepseekV2MoE(nn.Module):
             if (
                 self.alt_stream is not None
                 and self.num_fused_shared_experts == 0
+                and hidden_states.shape[0] > 0
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states, should_allreduce_fusion, use_reduce_scatter
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
                 )
             else:
                 return self.forward_normal(
-                    hidden_states, should_allreduce_fusion, use_reduce_scatter
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -471,7 +481,12 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
@@ -486,10 +501,14 @@
         ):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
-        shared_output = self._forward_shared_experts(hidden_states)
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(hidden_states)
-        topk_output = self.topk(hidden_states, router_logits)
+        if hidden_states.shape[0] > 0:
+            shared_output = self._forward_shared_experts(hidden_states)
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
+        else:
+            shared_output = None
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
+
         final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
@@ -501,7 +520,12 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
@@ -580,11 +604,8 @@ class DeepseekV2MoE(nn.Module):
                 ),
             )
         else:
-            topk_idx = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.device
+            )
 
         final_hidden_states = self.experts(
......
@@ -84,6 +84,7 @@ class _StageExecutor:
         forward_batch: ForwardBatch = inputs["forward_batch"]
         self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
         self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
+        self._global_num_tokens = forward_batch.global_num_tokens_cpu
 
     def next(self):
         assert not self.done
@@ -91,7 +92,11 @@
         stage = self._stages[self._index]
 
         if self._global_dp_buffer_len is not None:
-            set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
+            set_dp_buffer_len(
+                self._global_dp_buffer_len,
+                self._local_dp_buffer_len,
+                self._global_num_tokens,
+            )
 
         with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
             for op in stage:
......
@@ -230,6 +230,7 @@ class ServerArgs:
     enable_cudagraph_gc: bool = False
     enable_nccl_nvls: bool = False
     enable_symm_mem: bool = False
+    disable_flashinfer_cutlass_moe_fp4_allgather: bool = False
     enable_tokenizer_batch_encode: bool = False
     disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
@@ -1714,6 +1715,11 @@
             action="store_true",
             help="Enable NCCL symmetric memory for fast collectives.",
         )
+        parser.add_argument(
+            "--disable-flashinfer-cutlass-moe-fp4-allgather",
+            action="store_true",
+            help="Disables quantize before all-gather for flashinfer cutlass moe.",
+        )
         parser.add_argument(
             "--enable-tokenizer-batch-encode",
             action="store_true",
......