"vscode:/vscode.git/clone" did not exist on "de34e15abbc068f608c8b152070e017f53b35f2f"
Unverified Commit eff4eb3f authored by Trevor Morris's avatar Trevor Morris Committed by GitHub
Browse files

Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667)

parent 87dab548
......@@ -148,7 +148,11 @@ class PyNcclCommunicator:
)
def all_gather(
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
stream=None,
sizes: Optional[list[int]] = None,
):
if self.disabled:
return
......@@ -161,14 +165,33 @@ class PyNcclCommunicator:
)
if stream is None:
stream = self.stream
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
self.comm,
cudaStream_t(stream.cuda_stream),
)
if sizes is not None:
split_offset = 0
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
dst_slice = output_tensor[split_offset : split_offset + split_size]
self.nccl.ncclBroadcast(
buffer_type(input_tensor.data_ptr()),
buffer_type(dst_slice.data_ptr()),
dst_slice.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
root,
self.comm,
cudaStream_t(stream.cuda_stream),
)
split_offset += split_size
self.nccl.ncclGroupEnd()
else:
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
self.comm,
cudaStream_t(stream.cuda_stream),
)
def reduce_scatter(
self,
......@@ -176,6 +199,7 @@ class PyNcclCommunicator:
input_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None,
sizes: Optional[list[int]] = None,
):
if self.disabled:
return
......@@ -188,15 +212,35 @@ class PyNcclCommunicator:
)
if stream is None:
stream = self.stream
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
self.comm,
cudaStream_t(stream.cuda_stream),
)
if sizes is not None:
split_offset = 0
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
chunk = input_tensor[split_offset : split_offset + split_size, ...]
self.nccl.ncclReduce(
buffer_type(chunk.data_ptr()),
buffer_type(output_tensor.data_ptr()),
chunk.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
root,
self.comm,
cudaStream_t(stream.cuda_stream),
)
split_offset += split_size
self.nccl.ncclGroupEnd()
else:
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
self.comm,
cudaStream_t(stream.cuda_stream),
)
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
......@@ -266,6 +310,12 @@ class PyNcclCommunicator:
def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)
def group_start(self):
self.nccl.ncclGroupStart()
def group_end(self):
self.nccl.ncclGroupEnd()
@contextmanager
def change_state(
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
......
......@@ -206,6 +206,26 @@ class NCCLLibrary:
cudaStream_t,
],
),
# ncclResult_t ncclReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, int root,
# ncclComm_t comm, cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function(
"ncclReduce",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclRedOp_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
......@@ -278,6 +298,10 @@ class NCCLLibrary:
# it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
# ncclResult_t ncclGroupStart();
Function("ncclGroupStart", ncclResult_t, []),
# ncclResult_t ncclGroupEnd();
Function("ncclGroupEnd", ncclResult_t, []),
]
exported_functions_symm_mem = [
......@@ -400,6 +424,28 @@ class NCCLLibrary:
)
)
def ncclReduce(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
op: int,
root: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(
self._funcs["ncclReduce"](
sendbuff, recvbuff, count, datatype, op, root, comm, stream
)
)
def ncclReduceScatter(
self,
sendbuff: buffer_type,
......@@ -499,6 +545,12 @@ class NCCLLibrary:
def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
def ncclGroupStart(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
def ncclGroupEnd(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
__all__ = [
"NCCLLibrary",
......
......@@ -583,6 +583,39 @@ class GroupCoordinator:
torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
return output
def reduce_scatterv(
self,
input_: torch.Tensor,
output: Optional[torch.Tensor] = None,
sizes: Optional[List[int]] = None,
) -> torch.Tensor:
world_size = self.world_size
pynccl_comm = self.pynccl_comm
with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
assert (
pynccl_comm is not None and not pynccl_comm.disabled
), "pynccl is required for reduce_scatterv"
if sizes is not None:
assert len(sizes) == world_size
assert input_.shape[0] == sum(sizes)
chunk_size = sizes[self.rank_in_group]
else:
assert input_.shape[0] % world_size == 0
chunk_size = input_.shape[0] // world_size
output_shape = (chunk_size,) + input_.shape[1:]
if output is None:
output = torch.empty(
output_shape, dtype=input_.dtype, device=input_.device
)
else:
assert output.shape == output_shape
pynccl_comm.reduce_scatter(output, input_, sizes=sizes)
return output
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
......@@ -673,6 +706,54 @@ class GroupCoordinator:
)
return output_tensor
def all_gatherv(
self,
input_: Union[torch.Tensor, List[torch.Tensor]],
sizes: Optional[List[int]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Supports varying sizes per rank and input tensor list.
`sizes`: a list of len(world_size) with the number of items per rank to gather.
"""
world_size = self.world_size
pynccl_comm = self.pynccl_comm
with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
assert (
pynccl_comm is not None and not pynccl_comm.disabled
), "pynccl is required for all_gatherv"
def _all_gather_single(
input_: torch.Tensor, sizes: Optional[List[int]] = None
):
input_size = input_.size()
if sizes is not None:
assert len(sizes) == world_size
assert input_.shape[0] == sizes[self.rank_in_group]
output_size = (sum(sizes),) + input_size[1:]
# 'sizes' is not needed if all inputs in the same group have the same shape
if all(s == sizes[0] for s in sizes):
sizes = None
else:
output_size = (input_size[0] * world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
return output_tensor
if isinstance(input_, torch.Tensor):
return _all_gather_single(input_, sizes)
output_list = []
pynccl_comm.group_start()
for inp in input_:
output_list.append(_all_gather_single(inp, sizes=sizes))
pynccl_comm.group_end()
return output_list
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
......
......@@ -35,7 +35,10 @@ from sglang.srt.layers.dp_attention import (
get_global_dp_buffer,
get_local_dp_buffer,
)
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe import (
get_moe_a2a_backend,
should_use_flashinfer_cutlass_moe_fp4_allgather,
)
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -112,7 +115,11 @@ class LayerScatterModes:
if context.is_layer_sparse:
return (
ScatterMode.SCATTERED
if not get_moe_a2a_backend().is_none()
if (
# Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
not get_moe_a2a_backend().is_none()
or should_use_flashinfer_cutlass_moe_fp4_allgather()
)
else ScatterMode.FULL
)
else:
......
......@@ -72,6 +72,7 @@ class _DpGatheredBufferWrapper:
_device: torch.device
_global_dp_buffer_len: int
_local_dp_buffer_len: int
_global_num_tokens: Optional[List[int]]
@classmethod
def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
......@@ -80,9 +81,15 @@ class _DpGatheredBufferWrapper:
cls._device = device
@classmethod
def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
def set_dp_buffer_len(
cls,
global_dp_buffer_len: int,
local_dp_buffer_len: int,
global_num_tokens: Optional[List[int]] = None,
):
cls._global_dp_buffer_len = global_dp_buffer_len
cls._local_dp_buffer_len = local_dp_buffer_len
cls._global_num_tokens = global_num_tokens
@classmethod
def get_global_dp_buffer(cls) -> torch.Tensor:
......@@ -108,10 +115,18 @@ class _DpGatheredBufferWrapper:
def get_local_dp_buffer_len(cls) -> int:
return cls._local_dp_buffer_len
@classmethod
def get_dp_global_num_tokens(cls) -> List[int]:
return cls._global_num_tokens
def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
def set_dp_buffer_len(
global_dp_buffer_len: int,
local_dp_buffer_len: int,
global_num_tokens: Optional[List[int]] = None,
):
_DpGatheredBufferWrapper.set_dp_buffer_len(
global_dp_buffer_len, local_dp_buffer_len
global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
)
......@@ -131,6 +146,10 @@ def get_local_dp_buffer_len() -> int:
return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
def get_dp_global_num_tokens() -> List[int]:
return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
if not enable_dp_attention:
return tp_rank, tp_size, 0
......
......@@ -191,7 +191,11 @@ class LogitsMetadata:
else:
self.global_dp_buffer_len = self.global_dp_buffer_len
set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)
set_dp_buffer_len(
self.global_dp_buffer_len,
self.dp_local_num_tokens,
self.global_num_tokens_for_logprob_cpu,
)
class LogitsProcessor(nn.Module):
......
......@@ -10,6 +10,7 @@ from sglang.srt.layers.moe.utils import (
get_tbo_token_distribution_threshold,
initialize_moe_config,
is_tbo_enabled,
should_use_flashinfer_cutlass_moe_fp4_allgather,
should_use_flashinfer_trtllm_moe,
)
......@@ -23,6 +24,7 @@ __all__ = [
"get_moe_runner_backend",
"get_deepep_mode",
"should_use_flashinfer_trtllm_moe",
"should_use_flashinfer_cutlass_moe_fp4_allgather",
"is_tbo_enabled",
"get_tbo_token_distribution_threshold",
"get_deepep_config",
......
......@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
......@@ -621,9 +622,7 @@ class FusedMoE(torch.nn.Module):
if "ModelOpt" in self.quant_method.__class__.__name__:
# Determine per-tensor weight scale patterns based on variant
is_fp4_variant = (
"ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
)
is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
# FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
per_tensor_conditions = (
......
......@@ -327,6 +327,13 @@ class TopK(CustomOp):
expert_location_dispatch_info=expert_location_dispatch_info,
)
def empty_topk_output(self, device: torch.device) -> TopKOutput:
topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
return StandardTopKOutput(topk_weights, topk_idx, router_logits)
# ------------------------------- TopK implementation -------------------------------------
......
......@@ -7,6 +7,11 @@ from typing import TYPE_CHECKING, Optional
from packaging import version as pkg_version
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.dp_attention import (
get_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.utils import logger
if TYPE_CHECKING:
......@@ -99,6 +104,7 @@ DEEPEP_MODE: Optional[DeepEPMode] = None
IS_TBO_ENABLED: Optional[bool] = None
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
DEEPEP_CONFIG: Optional[str] = None
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
def initialize_moe_config(server_args: ServerArgs):
......@@ -108,6 +114,7 @@ def initialize_moe_config(server_args: ServerArgs):
global DEEPEP_CONFIG
global IS_TBO_ENABLED
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
......@@ -115,6 +122,9 @@ def initialize_moe_config(server_args: ServerArgs):
DEEPEP_CONFIG = server_args.deepep_config or ""
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
server_args.disable_flashinfer_cutlass_moe_fp4_allgather
)
def get_moe_a2a_backend() -> MoeA2ABackend:
......@@ -175,3 +185,16 @@ def should_use_flashinfer_trtllm_moe():
>= pkg_version.parse("0.2.9rc1")
)
return result
@lru_cache(maxsize=1)
def should_use_flashinfer_cutlass_moe_fp4_allgather():
"""
Perform FP4 quantize before all-gather for flashinfer cutlass moe to reduce communication cost for high-throughput serving.
"""
return (
not DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
and get_moe_runner_backend().is_flashinfer_cutlass()
and is_dp_attention_enabled()
and get_moe_expert_parallel_world_size() == get_attention_dp_size()
)
......@@ -7,7 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
from sglang.srt.distributed import get_tp_group
from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
from sglang.srt.layers.moe import (
should_use_flashinfer_cutlass_moe_fp4_allgather,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
......@@ -1176,16 +1181,37 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
), "apply_router_weight_on_input is not supported for Flashinfer"
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
# and fp4 quantized weights loaded from the checkpoint
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
output_dtype = x.dtype
x_sf = None
if should_use_flashinfer_cutlass_moe_fp4_allgather():
from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
# Quantize before comm, swizzle after.
if x.shape[0] > 0:
x, x_sf = fp4_quantize(
x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
)
else:
x_col = x.shape[1]
x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device)
x_sf = torch.zeros(
0, x_col // 16, dtype=torch.uint8, device=x.device
)
topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
[topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens()
)
x_sf = nvfp4_block_scale_interleave(x_sf)
output = flashinfer_cutlass_fused_moe(
x,
topk_ids.to(torch.int),
topk_weights,
layer.w13_weight.view(torch.long),
layer.w2_weight.view(torch.long),
x.dtype,
input=x,
token_selected_experts=topk_ids.to(torch.int),
token_final_scales=topk_weights,
fc1_expert_weights=layer.w13_weight.view(torch.long),
fc2_expert_weights=layer.w2_weight.view(torch.long),
output_dtype=output_dtype,
input_sf=x_sf,
quant_scales=[
layer.w13_input_scale_quant,
layer.w13_blockscale_swizzled.view(torch.int32),
......@@ -1202,6 +1228,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
)[0]
if moe_runner_config.routed_scaling_factor is not None:
output *= moe_runner_config.routed_scaling_factor
if should_use_flashinfer_cutlass_moe_fp4_allgather():
output, global_output = get_local_dp_buffer(), output
get_tp_group().reduce_scatterv(
global_output, output=output, sizes=get_dp_global_num_tokens()
)
return output
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
......
......@@ -84,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"chunked_prefill_size",
"device",
"disable_chunked_prefix_cache",
"disable_flashinfer_cutlass_moe_fp4_allgather",
"disable_radix_cache",
"enable_dp_lm_head",
"enable_flashinfer_allreduce_fusion",
......
......@@ -649,7 +649,7 @@ class ForwardBatch:
num_tokens = global_num_tokens[0]
self.global_dp_buffer_len = buffer_len
set_dp_buffer_len(buffer_len, num_tokens)
set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
bs = self.batch_size
......
......@@ -60,7 +60,11 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
should_use_flashinfer_cutlass_moe_fp4_allgather,
)
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
......@@ -343,7 +347,7 @@ class DeepseekV2MoE(nn.Module):
self.shared_experts_weight_block_size = None
if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
# disable tp for shared experts when enable deepep moe
# disable tp for shared experts when enable deepep moe, or with fp4 allgather
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
......@@ -354,6 +358,7 @@ class DeepseekV2MoE(nn.Module):
**(
dict(tp_rank=0, tp_size=1)
if get_moe_a2a_backend().is_deepep()
or should_use_flashinfer_cutlass_moe_fp4_allgather()
else {}
),
)
......@@ -433,14 +438,19 @@ class DeepseekV2MoE(nn.Module):
if (
self.alt_stream is not None
and self.num_fused_shared_experts == 0
and hidden_states.shape[0] > 0
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
):
return self.forward_normal_dual_stream(
hidden_states, should_allreduce_fusion, use_reduce_scatter
hidden_states,
should_allreduce_fusion,
use_reduce_scatter,
)
else:
return self.forward_normal(
hidden_states, should_allreduce_fusion, use_reduce_scatter
hidden_states,
should_allreduce_fusion,
use_reduce_scatter,
)
else:
return self.forward_deepep(hidden_states, forward_batch)
......@@ -471,7 +481,12 @@ class DeepseekV2MoE(nn.Module):
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
if (
self.tp_size > 1
and not should_allreduce_fusion
and not use_reduce_scatter
and not should_use_flashinfer_cutlass_moe_fp4_allgather()
):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
......@@ -486,10 +501,14 @@ class DeepseekV2MoE(nn.Module):
):
return self.forward_cpu(hidden_states, should_allreduce_fusion)
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
if hidden_states.shape[0] > 0:
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
else:
shared_output = None
topk_output = self.topk.empty_topk_output(hidden_states.device)
final_hidden_states = self.experts(hidden_states, topk_output)
if not _is_cuda and not _use_aiter:
......@@ -501,7 +520,12 @@ class DeepseekV2MoE(nn.Module):
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
if (
self.tp_size > 1
and not should_allreduce_fusion
and not use_reduce_scatter
and not should_use_flashinfer_cutlass_moe_fp4_allgather()
):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
......@@ -580,11 +604,8 @@ class DeepseekV2MoE(nn.Module):
),
)
else:
topk_idx = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
topk_weights, topk_idx, _ = self.topk.empty_topk_output(
hidden_states.device
)
final_hidden_states = self.experts(
......
......@@ -84,6 +84,7 @@ class _StageExecutor:
forward_batch: ForwardBatch = inputs["forward_batch"]
self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
self._global_num_tokens = forward_batch.global_num_tokens_cpu
def next(self):
assert not self.done
......@@ -91,7 +92,11 @@ class _StageExecutor:
stage = self._stages[self._index]
if self._global_dp_buffer_len is not None:
set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
set_dp_buffer_len(
self._global_dp_buffer_len,
self._local_dp_buffer_len,
self._global_num_tokens,
)
with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
for op in stage:
......
......@@ -230,6 +230,7 @@ class ServerArgs:
enable_cudagraph_gc: bool = False
enable_nccl_nvls: bool = False
enable_symm_mem: bool = False
disable_flashinfer_cutlass_moe_fp4_allgather: bool = False
enable_tokenizer_batch_encode: bool = False
disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False
......@@ -1714,6 +1715,11 @@ class ServerArgs:
action="store_true",
help="Enable NCCL symmetric memory for fast collectives.",
)
parser.add_argument(
"--disable-flashinfer-cutlass-moe-fp4-allgather",
action="store_true",
help="Disables quantize before all-gather for flashinfer cutlass moe.",
)
parser.add_argument(
"--enable-tokenizer-batch-encode",
action="store_true",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment