"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "ba0cd35c95fca70de0254b5157497591f779d7ff"
Unverified Commit e1e318af authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Remove MoE DP chunking (#39107)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent f7e62e3d
...@@ -200,7 +200,14 @@ steps: ...@@ -200,7 +200,14 @@ steps:
timeout_in_minutes: 90 timeout_in_minutes: 90
device: h100 device: h100
num_devices: 2 num_devices: 2
optional: true source_file_dependencies:
- csrc/quantization/cutlass_w8a8/moe/
- csrc/moe/
- tests/kernels/moe
- vllm/model_executor/layers/fused_moe/
- vllm/model_executor/layers/quantization/
- vllm/distributed/device_communicators/
- vllm/config
commands: commands:
- pytest -v -s kernels/moe/test_moe_layer.py - pytest -v -s kernels/moe/test_moe_layer.py
...@@ -209,6 +216,13 @@ steps: ...@@ -209,6 +216,13 @@ steps:
timeout_in_minutes: 90 timeout_in_minutes: 90
device: b200 device: b200
num_devices: 2 num_devices: 2
optional: true source_file_dependencies:
- csrc/quantization/cutlass_w8a8/moe/
- csrc/moe/
- tests/kernels/moe
- vllm/model_executor/layers/fused_moe/
- vllm/model_executor/layers/quantization/
- vllm/distributed/device_communicators/
- vllm/config
commands: commands:
- pytest -v -s kernels/moe/test_moe_layer.py - pytest -v -s kernels/moe/test_moe_layer.py
...@@ -46,6 +46,7 @@ from vllm.utils.import_utils import ( ...@@ -46,6 +46,7 @@ from vllm.utils.import_utils import (
has_deep_gemm, has_deep_gemm,
has_mori, has_mori,
) )
from vllm.utils.math_utils import next_power_of_2
from .mk_objects import ( from .mk_objects import (
TestMoEQuantConfig, TestMoEQuantConfig,
...@@ -604,13 +605,6 @@ def make_modular_kernel( ...@@ -604,13 +605,6 @@ def make_modular_kernel(
vllm_config: VllmConfig, vllm_config: VllmConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEKernel: ) -> mk.FusedMoEKernel:
def next_power_of_2(x):
import math
if x == 0:
return 1
return 2 ** math.ceil(math.log2(x))
# make moe config # make moe config
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
tp_size_=get_tensor_model_parallel_world_size(), tp_size_=get_tensor_model_parallel_world_size(),
......
...@@ -126,7 +126,7 @@ def parallel_launch_with_config( ...@@ -126,7 +126,7 @@ def parallel_launch_with_config(
world_size: int, world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig, Any, P], None], worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig, Any, P], None],
vllm_config: VllmConfig, vllm_config: VllmConfig,
env_dict: dict[Any, Any], env_dict: dict[Any, Any] | None,
*args: P.args, *args: P.args,
**kwargs: P.kwargs, **kwargs: P.kwargs,
) -> None: ) -> None:
......
...@@ -29,6 +29,7 @@ from vllm.utils.deep_gemm import ( ...@@ -29,6 +29,7 @@ from vllm.utils.deep_gemm import (
is_deep_gemm_supported, is_deep_gemm_supported,
) )
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
...@@ -84,14 +85,6 @@ def with_dp_metadata(M: int, world_size: int): ...@@ -84,14 +85,6 @@ def with_dp_metadata(M: int, world_size: int):
yield yield
def next_power_of_2(x):
import math
if x == 0:
return 1
return 2 ** math.ceil(math.log2(x))
def make_block_quant_fp8_weights( def make_block_quant_fp8_weights(
e: int, e: int,
n: int, n: int,
......
...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( ...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8 from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
from vllm.model_executor.models.llama4 import Llama4MoE from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
try: try:
...@@ -174,6 +175,7 @@ class TestData: ...@@ -174,6 +175,7 @@ class TestData:
routing_method=layer.routing_method_type, routing_method=layer.routing_method_type,
activation=activation, activation=activation,
device=w13_quantized.device, device=w13_quantized.device,
max_num_tokens=next_power_of_2(m),
) )
return TestData( return TestData(
...@@ -348,6 +350,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -348,6 +350,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
in_dtype=torch.bfloat16, in_dtype=torch.bfloat16,
is_act_and_mul=activation.is_gated, is_act_and_mul=activation.is_gated,
routing_method=RoutingMethodType.TopK, routing_method=RoutingMethodType.TopK,
max_num_tokens=next_power_of_2(m),
) )
kernel = mk.FusedMoEKernel( kernel = mk.FusedMoEKernel(
......
...@@ -29,6 +29,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( ...@@ -29,6 +29,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability( if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
...@@ -105,6 +106,7 @@ def test_flashinfer_fp4_moe_no_graph( ...@@ -105,6 +106,7 @@ def test_flashinfer_fp4_moe_no_graph(
in_dtype=dtype, in_dtype=dtype,
is_act_and_mul=is_gated_act, is_act_and_mul=is_gated_act,
routing_method=RoutingMethodType.TopK, routing_method=RoutingMethodType.TopK,
max_num_tokens=next_power_of_2(m),
) )
flashinfer_experts = FusedMoEKernel( flashinfer_experts = FusedMoEKernel(
......
...@@ -59,6 +59,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_w ...@@ -59,6 +59,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_w
from vllm.model_executor.models.mixtral import MixtralMoE from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
...@@ -1676,7 +1677,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend( ...@@ -1676,7 +1677,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
in_dtype=dtype, in_dtype=dtype,
is_act_and_mul=True, is_act_and_mul=True,
routing_method=RoutingMethodType.Renormalize, routing_method=RoutingMethodType.Renormalize,
max_num_tokens=m, max_num_tokens=next_power_of_2(m),
) )
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
......
...@@ -26,6 +26,7 @@ from tests.kernels.moe.utils import TestMLP, make_test_weights, moe_quantize_wei ...@@ -26,6 +26,7 @@ from tests.kernels.moe.utils import TestMLP, make_test_weights, moe_quantize_wei
from vllm.config import ( from vllm.config import (
CompilationConfig, CompilationConfig,
ParallelConfig, ParallelConfig,
SchedulerConfig,
VllmConfig, VllmConfig,
set_current_vllm_config, set_current_vllm_config,
) )
...@@ -53,7 +54,7 @@ from vllm.utils.flashinfer import ( ...@@ -53,7 +54,7 @@ from vllm.utils.flashinfer import (
has_flashinfer_nvlink_two_sided, has_flashinfer_nvlink_two_sided,
) )
from vllm.utils.import_utils import has_deep_ep, has_mori, has_nixl_ep from vllm.utils.import_utils import has_deep_ep, has_mori, has_nixl_ep
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv, next_power_of_2
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.workspace import ( from vllm.v1.worker.workspace import (
init_workspace_manager, init_workspace_manager,
...@@ -65,8 +66,9 @@ fp8_dtype = torch.float8_e4m3fn # current_platform.fp8_dtype ...@@ -65,8 +66,9 @@ fp8_dtype = torch.float8_e4m3fn # current_platform.fp8_dtype
SHAPE_COMBOS = [ SHAPE_COMBOS = [
(1, 128, 256), (1, 128, 256),
(32, 1024, 512), (32, 1024, 512),
(222, 2048, 2048), # should be big enough to exercise DP chunking (222, 2048, 2048),
] ]
MAX_M = max([x[0] for x in SHAPE_COMBOS])
NUM_EXPERTS = [8, 64] NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6] TOP_KS = [2, 6]
...@@ -112,7 +114,7 @@ BACKEND_SUPPORTED_QUANTS: dict[str, set[str | None]] = { ...@@ -112,7 +114,7 @@ BACKEND_SUPPORTED_QUANTS: dict[str, set[str | None]] = {
"mori": {None, "fp8", "modelopt_fp8"}, "mori": {None, "fp8", "modelopt_fp8"},
"flashinfer_nvlink_two_sided": {None, "modelopt_fp8", "modelopt_fp4"}, "flashinfer_nvlink_two_sided": {None, "modelopt_fp8", "modelopt_fp4"},
"flashinfer_nvlink_one_sided": {None, "modelopt_fp8", "modelopt_fp4"}, "flashinfer_nvlink_one_sided": {None, "modelopt_fp8", "modelopt_fp4"},
"deepep_low_latency": {None, "fp8", "modelopt_fp8", "modelopt_fp4"}, "deepep_low_latency": {None, "modelopt_fp8", "modelopt_fp4"},
"deepep_high_throughput": {None, "fp8", "modelopt_fp8", "modelopt_fp4"}, "deepep_high_throughput": {None, "fp8", "modelopt_fp8", "modelopt_fp4"},
"nixl_ep": {None, "fp8", "modelopt_fp8"}, "nixl_ep": {None, "fp8", "modelopt_fp8"},
} }
...@@ -363,9 +365,9 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]: ...@@ -363,9 +365,9 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]:
) )
# routed_input_transform + quantization + high hidden dimensions # routed_input_transform + quantization + high hidden dimensions
# TODO: Disable >= 2048 w/fp8 + deepep LL for now due to insane errors. # TODO: Disable >= 2048 for now due to insane errors.
if ( if (
(config.use_routed_input_transform or config.backend == "deepep_low_latency") config.use_routed_input_transform
and config.quantization is not None and config.quantization is not None
and config.k >= 2048 and config.k >= 2048
): ):
...@@ -1663,9 +1665,6 @@ def test_moe_layer( ...@@ -1663,9 +1665,6 @@ def test_moe_layer(
verbosity = pytestconfig.getoption("verbose") verbosity = pytestconfig.getoption("verbose")
test_env = dict()
test_env["VLLM_MOE_DP_CHUNK_SIZE"] = "128"
monkeypatch.setenv("VLLM_MOE_DP_CHUNK_SIZE", "128")
if os.environ.get("VLLM_LOGGING_LEVEL") is None: if os.environ.get("VLLM_LOGGING_LEVEL") is None:
monkeypatch.setenv("VLLM_LOGGING_LEVEL", "ERROR") monkeypatch.setenv("VLLM_LOGGING_LEVEL", "ERROR")
...@@ -1690,7 +1689,11 @@ def test_moe_layer( ...@@ -1690,7 +1689,11 @@ def test_moe_layer(
compilation_config.pass_config.fuse_allreduce_rms = False # for now compilation_config.pass_config.fuse_allreduce_rms = False # for now
vllm_config = VllmConfig( vllm_config = VllmConfig(
parallel_config=parallel_config, compilation_config=compilation_config parallel_config=parallel_config,
compilation_config=compilation_config,
scheduler_config=SchedulerConfig.default_factory(
max_num_batched_tokens=next_power_of_2(MAX_M)
),
) )
test_configs = generate_valid_test_configs( test_configs = generate_valid_test_configs(
...@@ -1718,7 +1721,7 @@ def test_moe_layer( ...@@ -1718,7 +1721,7 @@ def test_moe_layer(
world_size, world_size,
_parallel_worker, _parallel_worker,
vllm_config, vllm_config,
test_env, None,
test_configs, test_configs,
verbosity, verbosity,
) )
......
...@@ -69,6 +69,7 @@ def make_dummy_moe_config( ...@@ -69,6 +69,7 @@ def make_dummy_moe_config(
in_dtype=in_dtype, in_dtype=in_dtype,
device="cuda", device="cuda",
routing_method=RoutingMethodType.TopK, routing_method=RoutingMethodType.TopK,
max_num_tokens=512,
) )
......
...@@ -622,6 +622,18 @@ class ParallelConfig: ...@@ -622,6 +622,18 @@ class ParallelConfig:
and self.data_parallel_size > 1 and self.data_parallel_size > 1
) )
@property
def use_batched_dp_moe(self) -> bool:
return (
self.all2all_backend
in (
"deepep_low_latency",
"nixl_ep",
)
and self.enable_expert_parallel
and self.data_parallel_size > 1
)
@property @property
def node_rank_within_dp(self) -> int: def node_rank_within_dp(self) -> int:
return self.node_rank % self.nnodes_within_dp return self.node_rank % self.nnodes_within_dp
......
...@@ -40,6 +40,7 @@ class SchedulerConfig: ...@@ -40,6 +40,7 @@ class SchedulerConfig:
""" """
DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048 DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP: ClassVar[int] = 256
DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128 DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128
runner_type: RunnerType = "generate" runner_type: RunnerType = "generate"
......
...@@ -1588,9 +1588,6 @@ class EngineArgs: ...@@ -1588,9 +1588,6 @@ class EngineArgs:
self._check_feature_supported() self._check_feature_supported()
self._set_default_chunked_prefill_and_prefix_caching_args(model_config) self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
self._set_default_max_num_seqs_and_batched_tokens_args(
usage_context, model_config
)
self._set_default_reasoning_config_args() self._set_default_reasoning_config_args()
sliding_window: int | None = None sliding_window: int | None = None
if not is_interleaved(model_config.hf_text_config): if not is_interleaved(model_config.hf_text_config):
...@@ -1846,6 +1843,12 @@ class EngineArgs: ...@@ -1846,6 +1843,12 @@ class EngineArgs:
target_parallel_config=parallel_config, target_parallel_config=parallel_config,
) )
self._set_default_max_num_seqs_and_batched_tokens_args(
usage_context,
model_config,
parallel_config,
)
assert self.max_num_batched_tokens is not None, ( assert self.max_num_batched_tokens is not None, (
"max_num_batched_tokens must be set by this point" "max_num_batched_tokens must be set by this point"
) )
...@@ -2244,6 +2247,7 @@ class EngineArgs: ...@@ -2244,6 +2247,7 @@ class EngineArgs:
self, self,
usage_context: UsageContext | None, usage_context: UsageContext | None,
model_config: ModelConfig, model_config: ModelConfig,
parallel_config: ParallelConfig,
): ):
world_size = self.pipeline_parallel_size * self.tensor_parallel_size world_size = self.pipeline_parallel_size * self.tensor_parallel_size
( (
...@@ -2255,10 +2259,15 @@ class EngineArgs: ...@@ -2255,10 +2259,15 @@ class EngineArgs:
orig_max_num_seqs = self.max_num_seqs orig_max_num_seqs = self.max_num_seqs
if self.max_num_batched_tokens is None: if self.max_num_batched_tokens is None:
self.max_num_batched_tokens = default_max_num_batched_tokens.get( if parallel_config.use_batched_dp_moe:
usage_context, self.max_num_batched_tokens = (
SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS, SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP
) )
else:
self.max_num_batched_tokens = default_max_num_batched_tokens.get(
usage_context,
SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
if self.max_num_seqs is None: if self.max_num_seqs is None:
self.max_num_seqs = default_max_num_seqs.get( self.max_num_seqs = default_max_num_seqs.get(
......
...@@ -146,8 +146,6 @@ if TYPE_CHECKING: ...@@ -146,8 +146,6 @@ if TYPE_CHECKING:
VLLM_ENABLE_PREGRAD_PASSES: bool = False VLLM_ENABLE_PREGRAD_PASSES: bool = False
VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0 VLLM_DP_MASTER_PORT: int = 0
VLLM_MOE_DP_CHUNK_SIZE: int = 256
VLLM_ENABLE_MOE_DP_CHUNK: bool = True
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict" VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict"
VLLM_RAY_EXTRA_ENV_VAR_PREFIXES_TO_COPY: str = "" VLLM_RAY_EXTRA_ENV_VAR_PREFIXES_TO_COPY: str = ""
...@@ -1140,15 +1138,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1140,15 +1138,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DP_MASTER_IP": lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"), "VLLM_DP_MASTER_IP": lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"),
# Port of the master node in the data parallel setting # Port of the master node in the data parallel setting
"VLLM_DP_MASTER_PORT": lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")), "VLLM_DP_MASTER_PORT": lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")),
# In the context of executing MoE models with Data-Parallel, Expert-Parallel
# and Batched All-to-All dispatch/combine kernels, VLLM_MOE_DP_CHUNK_SIZE
# dictates the quantum of tokens that can be dispatched from a DP
# rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE
# units.
"VLLM_MOE_DP_CHUNK_SIZE": lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")),
"VLLM_ENABLE_MOE_DP_CHUNK": lambda: bool(
int(os.getenv("VLLM_ENABLE_MOE_DP_CHUNK", "1"))
),
# Randomize inputs during dummy runs when using Data Parallel # Randomize inputs during dummy runs when using Data Parallel
"VLLM_RANDOMIZE_DP_DUMMY_INPUTS": lambda: ( "VLLM_RANDOMIZE_DP_DUMMY_INPUTS": lambda: (
os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1" os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1"
......
...@@ -70,27 +70,8 @@ def _compute_sp_num_tokens( ...@@ -70,27 +70,8 @@ def _compute_sp_num_tokens(
return sp_tokens.tolist() return sp_tokens.tolist()
def _compute_chunked_local_num_tokens(
num_tokens_across_dp_cpu: torch.Tensor,
sequence_parallel_size: int,
max_num_tokens: int,
chunk_idx: int,
) -> list[int]:
sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu, sequence_parallel_size)
sp_size = len(sp_tokens)
local_size = [-1] * sp_size
for i in range(sp_size):
# Take into account sharding if MoE activation is sequence parallel.
local_size[i] = min(max_num_tokens, sp_tokens[i] - (max_num_tokens * chunk_idx))
if local_size[i] <= 0:
local_size[i] = 1 # ensure lockstep even if done
return local_size
@dataclass @dataclass
class DPMetadata: class DPMetadata:
max_tokens_across_dp_cpu: torch.Tensor
num_tokens_across_dp_cpu: torch.Tensor num_tokens_across_dp_cpu: torch.Tensor
# NOTE: local_sizes should only be set by the chunked_sizes context manager # NOTE: local_sizes should only be set by the chunked_sizes context manager
...@@ -113,47 +94,7 @@ class DPMetadata: ...@@ -113,47 +94,7 @@ class DPMetadata:
assert num_tokens_across_dp_cpu[dp_rank] == batchsize, ( assert num_tokens_across_dp_cpu[dp_rank] == batchsize, (
f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}" f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
) )
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu) return DPMetadata(num_tokens_across_dp_cpu)
return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
@contextmanager
def chunked_sizes(
self, sequence_parallel_size: int, max_chunk_size_per_rank: int, chunk_idx: int
):
"""
Context manager to compute and temporarily set the per-rank local token
sizes for a specific chunk during chunked forward execution.
This is necessary to ensure each DP (data parallel) rank processes its
designated portion of tokens in lockstep with others, even when the
token counts are uneven or some ranks have completed their input early.
For chunked execution, we break up the total tokens on each rank into
multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
`chunk_idx`, this context manager sets `self.local_sizes` to the number
of tokens to process in that chunk on each rank.
`self.local_sizes` is only valid inside the context.
Args:
sequence_parallel_size: When Attn is TP and MoE layers are EP,
we use SP between the layers to avoid
redundant ops. We need this value to
compute the chunked sizes.
max_chunk_size_per_rank: The max number of tokens each rank is
allowed to process in this chunk.
chunk_idx: The index of the chunk to compute sizes for.
"""
self.local_sizes = _compute_chunked_local_num_tokens(
self.num_tokens_across_dp_cpu,
sequence_parallel_size,
max_chunk_size_per_rank,
chunk_idx,
)
try:
yield self.local_sizes
finally:
self.local_sizes = None
@contextmanager @contextmanager
def sp_local_sizes(self, sequence_parallel_size: int): def sp_local_sizes(self, sequence_parallel_size: int):
......
...@@ -6,8 +6,7 @@ from typing import Union ...@@ -6,8 +6,7 @@ from typing import Union
import torch import torch
import vllm.envs as envs from vllm.config import ParallelConfig, SchedulerConfig
from vllm.config import ParallelConfig
from vllm.distributed import get_dp_group, get_pcp_group, get_tensor_model_parallel_rank from vllm.distributed import get_dp_group, get_pcp_group, get_tensor_model_parallel_rank
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
...@@ -937,15 +936,6 @@ class FusedMoEParallelConfig: ...@@ -937,15 +936,6 @@ class FusedMoEParallelConfig:
all2all_backend: str # all2all backend for MoE communication all2all_backend: str # all2all backend for MoE communication
enable_eplb: bool # whether to enable expert load balancing enable_eplb: bool # whether to enable expert load balancing
@property
def use_dp_chunking(self) -> bool:
return (
self.use_deepep_ll_kernels
or self.use_mori_kernels
or self.use_fi_nvl_two_sided_kernels
or self.use_nixl_ep_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
@property @property
def is_sequence_parallel(self) -> bool: def is_sequence_parallel(self) -> bool:
return self.sp_size > 1 return self.sp_size > 1
...@@ -1184,7 +1174,7 @@ class FusedMoEConfig: ...@@ -1184,7 +1174,7 @@ class FusedMoEConfig:
intermediate_size_per_partition_unpadded: int | None = None intermediate_size_per_partition_unpadded: int | None = None
moe_backend: str = "auto" moe_backend: str = "auto"
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE max_num_tokens: int = SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP
has_bias: bool = False has_bias: bool = False
is_act_and_mul: bool = True is_act_and_mul: bool = True
is_lora_enabled: bool = False is_lora_enabled: bool = False
......
...@@ -8,7 +8,6 @@ from typing import Literal, cast, get_args, overload ...@@ -8,7 +8,6 @@ from typing import Literal, cast, get_args, overload
import torch import torch
from torch.nn.parameter import UninitializedParameter from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.parallel import ExpertPlacementStrategy from vllm.config.parallel import ExpertPlacementStrategy
...@@ -479,7 +478,7 @@ class FusedMoE(CustomOp): ...@@ -479,7 +478,7 @@ class FusedMoE(CustomOp):
in_dtype=moe_in_dtype, in_dtype=moe_in_dtype,
moe_backend=vllm_config.kernel_config.moe_backend, moe_backend=vllm_config.kernel_config.moe_backend,
router_logits_dtype=router_logits_dtype, router_logits_dtype=router_logits_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens,
has_bias=has_bias, has_bias=has_bias,
is_act_and_mul=is_act_and_mul, is_act_and_mul=is_act_and_mul,
is_lora_enabled=vllm_config.lora_config is not None, is_lora_enabled=vllm_config.lora_config is not None,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.forward_context import (
get_forward_context,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.runner.moe_runner_base import MoERunnerBase
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
)
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
from vllm.v1.worker.workspace import current_workspace_manager
class ChunkingMoERunner(MoERunnerBase):
"""
MoE runner wrapper that adds chunked processing to any MoERunnerBase.
This runner wraps an inner MoERunnerBase and overrides _forward_impl to
process large batches by breaking them into smaller chunks. Each chunk
is delegated to the inner runner's _forward_impl, making chunking
composable with any runner implementation.
All MoERunnerBase state (moe_config, router, quant_method, etc.) is
transparently delegated to the inner runner via __getattr__.
ChunkingMoERunner only owns chunking-specific state: the pre-allocated
workspace buffers and the reduce_results override.
Key behaviors:
- Pre-allocates workspace tensors for CUDA graph compatibility
- Processes chunks via inner._forward_impl per chunk
- Never reduces results (reduce_results always returns False)
"""
def __init__(self, inner: MoERunnerBase):
# Assert that _maybe_dispatch/_maybe_combine will be nops.
assert inner.moe_config.pcp_size == 1
# Skip MoERunnerBase.__init__ — all state is delegated to inner
# via __getattr__. Only chunking-specific state lives here.
self._inner = inner
# Pre-allocated staging buffers. These need to exist ahead of time
# due to CUDA graph construction needing fixed buffer addresses.
self.batched_hidden_states, self.batched_router_logits = (
self._init_dp_chunking()
)
def __getattr__(self, name):
# Delegate attribute access to the inner runner. This is only
# called when normal lookup (instance __dict__, class MRO) fails,
# so ChunkingMoERunner's own attributes and methods take priority.
return getattr(self._inner, name)
@property
def shared_experts(self) -> SharedExperts | None:
return self._inner.shared_experts
# TODO(bnell): temporary hack, do not call this method.
def _replace_quant_method(self, quant_method: FusedMoEMethodBase):
self._inner._replace_quant_method(quant_method)
self.quant_method = quant_method
def is_internal_router(self) -> bool:
return self._inner.gate is not None
# Reducing results when chunking is handled by the MK finalize operations
# when DP chunking is enabled..
# This will be removed by #35949
@property
def reduce_results(self) -> bool:
return False
def _init_dp_chunking(self) -> list[torch.Tensor]:
states_shape: tuple[int, ...]
logits_shape: tuple[int, ...]
moe = self.moe_config
if self.enable_dbo:
states_shape = (2, moe.max_num_tokens, self.moe_config.hidden_dim)
logits_shape = (2, moe.max_num_tokens, self.moe_config.num_logical_experts)
else:
states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim)
logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts)
# Does this need some kind of profiling run check like modular_kernel.py?
return current_workspace_manager().get_simultaneous(
(states_shape, moe.in_dtype),
(logits_shape, moe.router_logits_dtype),
)
def _allocate_dp_chunking_outputs(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> tuple[torch.Tensor | None, torch.Tensor]:
# Assert the inputs are of the proper type and shape.
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == hidden_states.dtype, (
f"{self.batched_hidden_states.dtype} == {hidden_states.dtype}"
)
assert self.batched_router_logits.dtype == router_logits.dtype, (
f"{self.batched_router_logits.dtype} == {router_logits.dtype}"
)
# Check size compatibility.
assert self.batched_hidden_states.size(-1) == hidden_states.size(-1)
assert self.batched_router_logits.size(-1) == router_logits.size(-1)
final_fused_hidden_states = torch.empty_like(hidden_states)
if self.shared_experts is not None:
if shared_experts_input is not None:
final_shared_hidden_states = torch.empty_like(shared_experts_input)
else:
final_shared_hidden_states = torch.empty_like(hidden_states)
else:
final_shared_hidden_states = None
return final_shared_hidden_states, final_fused_hidden_states
def _slice_and_copy_input(
self,
out_slice: torch.Tensor,
orig: torch.Tensor | None,
start: int,
end: int,
) -> torch.Tensor:
assert orig is not None
slice_size = end - start
orig_slice = orig[start:end, :]
if self.enable_dbo:
assert out_slice.dim() == 3
batch_buffer_idx = dbo_current_ubatch_id()
out_slice = out_slice[batch_buffer_idx, :]
assert out_slice.size(0) >= slice_size
out_slice = out_slice[:slice_size, :]
out_slice.copy_(orig_slice, non_blocking=True)
return out_slice
def _forward_impl(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
final_shared_hidden_states, final_fused_hidden_states = (
self._allocate_dp_chunking_outputs(
hidden_states, router_logits, shared_experts_input
)
)
ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
# If the input to the MoE is sequence parallel then divide by sp_size
# to find the maximum number of tokens for any individual dispatcher.
if self.moe_config.is_sequence_parallel:
max_tokens_across_dispatchers = cdiv(
max_tokens_across_dispatchers, self.moe_config.sp_size
)
num_tokens = hidden_states.size(0)
for chunk_idx, chunk_start_ in enumerate(
range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
):
chunk_start = chunk_start_
chunk_end = min(
chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
)
# clamp start and end
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
chunk_sizes = ctx.dp_metadata.chunked_sizes(
self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
)
with chunk_sizes:
hidden_states_chunk = self._slice_and_copy_input(
self.batched_hidden_states,
hidden_states,
chunk_start,
chunk_end,
)
router_logits_chunk = self._slice_and_copy_input(
self.batched_router_logits,
router_logits,
chunk_start,
chunk_end,
)
shared_experts_input_chunk = (
shared_experts_input[chunk_start:chunk_end, :]
if shared_experts_input is not None
else None
)
# Delegate per-chunk computation to the inner runner.
chunk_result = self._inner._forward_impl(
layer=layer,
hidden_states=hidden_states_chunk,
router_logits=router_logits_chunk,
shared_experts_input=shared_experts_input_chunk,
)
# Store outputs
# TODO(bnell): document when chunk_start >= num_tokens
if chunk_start < num_tokens:
if self.shared_experts is not None:
assert isinstance(chunk_result, tuple)
shared_output_chunk, hidden_states_chunk = chunk_result
final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
hidden_states_chunk, non_blocking=True
)
assert shared_output_chunk is not None
assert final_shared_hidden_states is not None
final_shared_hidden_states[chunk_start:chunk_end, :].copy_(
shared_output_chunk, non_blocking=True
)
else:
assert isinstance(chunk_result, torch.Tensor)
final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
chunk_result, non_blocking=True
)
if self.shared_experts is None:
return final_fused_hidden_states
else:
assert final_shared_hidden_states is not None
return (final_shared_hidden_states, final_fused_hidden_states)
...@@ -12,9 +12,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( ...@@ -12,9 +12,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter, FusedMoERouter,
) )
from vllm.model_executor.layers.fused_moe.runner.chunking_moe_runner import (
ChunkingMoERunner,
)
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import ( from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
DefaultMoERunner, DefaultMoERunner,
) )
...@@ -35,7 +32,7 @@ def create_moe_runner( ...@@ -35,7 +32,7 @@ def create_moe_runner(
reduce_results: bool, reduce_results: bool,
enable_dbo: bool, enable_dbo: bool,
) -> MoERunner: ) -> MoERunner:
runner = DefaultMoERunner( return DefaultMoERunner(
layer_name, layer_name,
moe_config, moe_config,
router, router,
...@@ -46,6 +43,3 @@ def create_moe_runner( ...@@ -46,6 +43,3 @@ def create_moe_runner(
reduce_results, reduce_results,
enable_dbo, enable_dbo,
) )
if moe_config.moe_parallel_config.use_dp_chunking:
return ChunkingMoERunner(runner)
return runner
...@@ -69,7 +69,6 @@ class SharedExperts: ...@@ -69,7 +69,6 @@ class SharedExperts:
self._moe_config = moe_config self._moe_config = moe_config
self._quant_method = quant_method self._quant_method = quant_method
self._reduce_results = reduce_results self._reduce_results = reduce_results
self._use_dp_chunking = moe_config.moe_parallel_config.use_dp_chunking
# Allow disabling of the separate shared experts stream for # Allow disabling of the separate shared experts stream for
# debug purposes. # debug purposes.
...@@ -87,20 +86,6 @@ class SharedExperts: ...@@ -87,20 +86,6 @@ class SharedExperts:
"Enabled separate cuda stream for MoE shared_experts", scope="local" "Enabled separate cuda stream for MoE shared_experts", scope="local"
) )
@property
def _use_external_experts(self) -> bool:
if self._use_dp_chunking:
return False
# Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues
# - we are using flashinfer with DP, since there nothing to gain
backend = self._moe_config.moe_parallel_config.all2all_backend
return (
self._moe_config.moe_parallel_config.enable_eplb
and backend != "allgather_reducescatter"
) or self._moe_config.moe_parallel_config.use_fi_nvl_two_sided_kernels
def _determine_shared_experts_order( def _determine_shared_experts_order(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -110,7 +95,6 @@ class SharedExperts: ...@@ -110,7 +95,6 @@ class SharedExperts:
should_run_shared_in_aux_stream = ( should_run_shared_in_aux_stream = (
current_platform.is_cuda() current_platform.is_cuda()
and not self._use_dp_chunking
and self._stream is not None and self._stream is not None
and hidden_states.shape[0] and hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
......
...@@ -1502,9 +1502,9 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod): ...@@ -1502,9 +1502,9 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False) layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
# FIXME warp need to be adjusted based on batch size # FIXME warp need to be adjusted based on batch size
# only apply to batched mode # only apply to batched mode
if self.moe.use_ep: if self.moe.use_ep:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 num_warps = 4 if self.moe.max_num_tokens <= 512 else 8
else: else:
num_warps = 8 num_warps = 8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment