Unverified Commit 2cdf9222 authored by Xinan Miao's avatar Xinan Miao Committed by GitHub
Browse files

[Feature]: Remove Chunking From FusedMoE (#34086)


Signed-off-by: default avatarSouthWest7 <am1ao@qq.com>
Signed-off-by: default avatarSouthwest <1403572259@qq.com>
Signed-off-by: default avatarsouthwest <am1ao@qq.com>
Signed-off-by: default avatarXinan Miao <1403572259@qq.com>
Co-authored-by: default avatarSouthWest7 <am1ao@qq.com>
parent c973ecde
...@@ -167,9 +167,6 @@ FusedMoEExpertsModular performs the core of the FusedMoE operations. The various ...@@ -167,9 +167,6 @@ FusedMoEExpertsModular performs the core of the FusedMoE operations. The various
`FusedMoEExpertsModular::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format. `FusedMoEExpertsModular::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format.
`FusedMoEExpertsModular::supports_chunking()`: Return True if the implementation supports chunking. Typically
implementations that input `FusedMoEActivationFormat.Standard` support chunking and `FusedMoEActivationFormat.BatchedExperts` do not.
`FusedMoEExpertsModular::supports_expert_map()`: Return True if the implementation supports expert map. `FusedMoEExpertsModular::supports_expert_map()`: Return True if the implementation supports expert map.
`FusedMoEExpertsModular::workspace_shapes()` / `FusedMoEExpertsModular::workspace_shapes()` /
...@@ -220,8 +217,8 @@ If you are adding some `FusedMoEPrepareAndFinalizeModular` / `FusedMoEExpertsMod ...@@ -220,8 +217,8 @@ If you are adding some `FusedMoEPrepareAndFinalizeModular` / `FusedMoEExpertsMod
1. Add the implementation type to `MK_ALL_PREPARE_FINALIZE_TYPES` and `MK_FUSED_EXPERT_TYPES` in [mk_objects.py](../../tests/kernels/moe/modular_kernel_tools/mk_objects.py) respectively. 1. Add the implementation type to `MK_ALL_PREPARE_FINALIZE_TYPES` and `MK_FUSED_EXPERT_TYPES` in [mk_objects.py](../../tests/kernels/moe/modular_kernel_tools/mk_objects.py) respectively.
2. Update `Config::is_batched_prepare_finalize()`, `Config::is_batched_fused_experts()`, `Config::is_standard_fused_experts()`, 2. Update `Config::is_batched_prepare_finalize()`, `Config::is_batched_fused_experts()`, `Config::is_standard_fused_experts()`,
`Config::is_fe_16bit_supported()`, `Config::is_fe_fp8_supported()`, `Config::is_fe_block_fp8_supported()`, `Config::is_fe_16bit_supported()`, `Config::is_fe_fp8_supported()`, `Config::is_fe_block_fp8_supported()`
`Config::is_fe_supports_chunking()` methods in [/tests/kernels/moe/modular_kernel_tools/common.py](../../tests/kernels/moe/modular_kernel_tools/common.py) methods in [/tests/kernels/moe/modular_kernel_tools/common.py](../../tests/kernels/moe/modular_kernel_tools/common.py)
Doing this will add the new implementation to the test suite. Doing this will add the new implementation to the test suite.
......
...@@ -82,11 +82,6 @@ def make_config_arg_parser(description: str): ...@@ -82,11 +82,6 @@ def make_config_arg_parser(description: str):
"--num-experts", type=int, default=32, help="Global num experts" "--num-experts", type=int, default=32, help="Global num experts"
) )
parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk") parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk")
parser.add_argument(
"--fused-moe-chunk-size",
type=int,
help="Fused moe chunk size used for the non-batched fused experts impl.",
)
# Quant args # Quant args
parser.add_argument( parser.add_argument(
...@@ -158,7 +153,6 @@ def make_config(args: argparse.Namespace) -> Config: ...@@ -158,7 +153,6 @@ def make_config(args: argparse.Namespace) -> Config:
quant_config=quant_config, quant_config=quant_config,
prepare_finalize_type=args.pf_type, prepare_finalize_type=args.pf_type,
fused_experts_type=args.experts_type, fused_experts_type=args.experts_type,
fused_moe_chunk_size=args.fused_moe_chunk_size,
world_size=args.world_size, world_size=args.world_size,
torch_trace_dir_path=args.torch_trace_dir_path, torch_trace_dir_path=args.torch_trace_dir_path,
) )
...@@ -68,7 +68,6 @@ class Config: ...@@ -68,7 +68,6 @@ class Config:
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
fused_experts_type: mk.FusedMoEExperts fused_experts_type: mk.FusedMoEExperts
fused_moe_chunk_size: int | None
world_size: int world_size: int
torch_trace_dir_path: str | None = None torch_trace_dir_path: str | None = None
...@@ -89,7 +88,6 @@ class Config: ...@@ -89,7 +88,6 @@ class Config:
s += f" K={self.K}\n" s += f" K={self.K}\n"
s += f" topk={self.topks}\n" s += f" topk={self.topks}\n"
s += f" dtype={self.dtype}\n" s += f" dtype={self.dtype}\n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n"
s += " Quant:\n" s += " Quant:\n"
if self.quant_config is not None: if self.quant_config is not None:
s += f" q_dtype={self.quant_dtype}\n" s += f" q_dtype={self.quant_dtype}\n"
...@@ -152,11 +150,6 @@ class Config: ...@@ -152,11 +150,6 @@ class Config:
vllm_config.parallel_config.all2all_backend = self.all2all_backend() vllm_config.parallel_config.all2all_backend = self.all2all_backend()
if self.fused_moe_chunk_size is not None:
env_dict.update(
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}
)
return vllm_config, env_dict return vllm_config, env_dict
def is_fp8_block_quantized(self): def is_fp8_block_quantized(self):
...@@ -189,10 +182,6 @@ class Config: ...@@ -189,10 +182,6 @@ class Config:
info = expert_info(self.fused_experts_type) info = expert_info(self.fused_experts_type)
return info.blocked_quantization_support return info.blocked_quantization_support
def is_fe_supports_chunking(self):
info = expert_info(self.fused_experts_type)
return info.supports_chunking
def supports_expert_map(self): def supports_expert_map(self):
info = expert_info(self.fused_experts_type) info = expert_info(self.fused_experts_type)
return info.supports_expert_map return info.supports_expert_map
...@@ -233,10 +222,6 @@ class Config: ...@@ -233,10 +222,6 @@ class Config:
if not self.is_standard_fused_experts(): if not self.is_standard_fused_experts():
return False, "Mismatched format." return False, "Mismatched format."
use_chunking = self.fused_moe_chunk_size is not None
if use_chunking and not self.is_fe_supports_chunking():
return False, "Chunking not supported."
# Check quantization sanity # Check quantization sanity
if ( if (
int(self.is_per_act_token_quant) int(self.is_per_act_token_quant)
......
...@@ -42,12 +42,6 @@ def rank_worker( ...@@ -42,12 +42,6 @@ def rank_worker(
): ):
set_random_seed(pgi.rank) set_random_seed(pgi.rank)
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
# get weights to this device # get weights to this device
weights.to_current_device() weights.to_current_device()
...@@ -135,7 +129,6 @@ def make_feature_matrix(csv_file_path: str): ...@@ -135,7 +129,6 @@ def make_feature_matrix(csv_file_path: str):
fused_experts_type=experts_type, fused_experts_type=experts_type,
quant_config=quant_config, quant_config=quant_config,
world_size=2, world_size=2,
fused_moe_chunk_size=None,
) )
success = None success = None
......
...@@ -64,7 +64,6 @@ class ExpertInfo: ...@@ -64,7 +64,6 @@ class ExpertInfo:
activation_format: mk.FusedMoEActivationFormat activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[torch.dtype | str] supported_dtypes: list[torch.dtype | str]
blocked_quantization_support: bool blocked_quantization_support: bool
supports_chunking: bool
supports_expert_map: bool supports_expert_map: bool
needs_matching_quant: bool = False needs_matching_quant: bool = False
needs_deep_gemm: bool = False needs_deep_gemm: bool = False
...@@ -127,7 +126,6 @@ def register_experts( ...@@ -127,7 +126,6 @@ def register_experts(
activation_format: mk.FusedMoEActivationFormat, activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[torch.dtype | str], supported_dtypes: list[torch.dtype | str],
blocked_quantization_support: bool, blocked_quantization_support: bool,
supports_chunking: bool,
supports_expert_map: bool, supports_expert_map: bool,
needs_matching_quant: bool = False, needs_matching_quant: bool = False,
needs_deep_gemm: bool = False, needs_deep_gemm: bool = False,
...@@ -141,7 +139,6 @@ def register_experts( ...@@ -141,7 +139,6 @@ def register_experts(
activation_format, activation_format,
supported_dtypes, supported_dtypes,
blocked_quantization_support, blocked_quantization_support,
supports_chunking,
supports_expert_map, supports_expert_map,
needs_matching_quant, needs_matching_quant,
needs_deep_gemm, needs_deep_gemm,
...@@ -176,7 +173,6 @@ register_experts( ...@@ -176,7 +173,6 @@ register_experts(
batched_format, batched_format,
common_float_types, common_float_types,
blocked_quantization_support=True, blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False, supports_expert_map=False,
needs_matching_quant=True, needs_matching_quant=True,
) )
...@@ -186,7 +182,6 @@ register_experts( ...@@ -186,7 +182,6 @@ register_experts(
standard_format, standard_format,
common_float_and_int_types, common_float_and_int_types,
blocked_quantization_support=True, blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True, supports_expert_map=True,
needs_matching_quant=True, needs_matching_quant=True,
) )
...@@ -196,7 +191,6 @@ register_experts( ...@@ -196,7 +191,6 @@ register_experts(
batched_format, batched_format,
common_float_and_int_types, common_float_and_int_types,
blocked_quantization_support=True, blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=True, supports_expert_map=True,
) )
...@@ -262,7 +256,6 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability ...@@ -262,7 +256,6 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
standard_format, standard_format,
nvfp4_types + fp8_types, nvfp4_types + fp8_types,
blocked_quantization_support=True, blocked_quantization_support=True,
supports_chunking=True,
# Note: this is a hack to get it to run for now # Note: this is a hack to get it to run for now
supports_expert_map=True, supports_expert_map=True,
) )
...@@ -281,7 +274,6 @@ if has_aiter(): ...@@ -281,7 +274,6 @@ if has_aiter():
standard_format, standard_format,
fp8_types, fp8_types,
blocked_quantization_support=True, blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True, supports_expert_map=True,
needs_aiter=True, needs_aiter=True,
) )
...@@ -294,7 +286,6 @@ if has_deep_gemm() and is_deep_gemm_supported(): ...@@ -294,7 +286,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
batched_format, batched_format,
fp8_types, fp8_types,
blocked_quantization_support=True, blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False, supports_expert_map=False,
needs_matching_quant=False, needs_matching_quant=False,
needs_deep_gemm=True, needs_deep_gemm=True,
...@@ -304,7 +295,6 @@ if has_deep_gemm() and is_deep_gemm_supported(): ...@@ -304,7 +295,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
standard_format, standard_format,
fp8_types, fp8_types,
blocked_quantization_support=True, blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True, supports_expert_map=True,
needs_matching_quant=False, needs_matching_quant=False,
needs_deep_gemm=True, needs_deep_gemm=True,
...@@ -314,7 +304,6 @@ if has_deep_gemm() and is_deep_gemm_supported(): ...@@ -314,7 +304,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
standard_format, standard_format,
common_float_and_int_types, common_float_and_int_types,
blocked_quantization_support=True, blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True, supports_expert_map=True,
needs_matching_quant=True, needs_matching_quant=True,
needs_deep_gemm=True, needs_deep_gemm=True,
...@@ -331,7 +320,6 @@ if cutlass_fp8_supported(): ...@@ -331,7 +320,6 @@ if cutlass_fp8_supported():
standard_format, standard_format,
fp8_types, fp8_types,
blocked_quantization_support=False, blocked_quantization_support=False,
supports_chunking=True,
supports_expert_map=False, supports_expert_map=False,
) )
register_experts( register_experts(
...@@ -339,7 +327,6 @@ if cutlass_fp8_supported(): ...@@ -339,7 +327,6 @@ if cutlass_fp8_supported():
batched_format, batched_format,
fp8_types, fp8_types,
blocked_quantization_support=False, blocked_quantization_support=False,
supports_chunking=False,
supports_expert_map=False, supports_expert_map=False,
) )
else: else:
...@@ -354,7 +341,6 @@ if cutlass_fp4_supported(): ...@@ -354,7 +341,6 @@ if cutlass_fp4_supported():
standard_format, standard_format,
nvfp4_types, nvfp4_types,
blocked_quantization_support=True, blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=False, supports_expert_map=False,
) )
else: else:
......
...@@ -85,12 +85,6 @@ def rank_worker( ...@@ -85,12 +85,6 @@ def rank_worker(
): ):
set_random_seed(pgi.rank) set_random_seed(pgi.rank)
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
# get weights to this device # get weights to this device
weights.to_current_device() weights.to_current_device()
......
...@@ -158,8 +158,6 @@ def test_w8a8_block_fp8_fused_moe( ...@@ -158,8 +158,6 @@ def test_w8a8_block_fp8_fused_moe(
torch.manual_seed(seed) torch.manual_seed(seed)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
...@@ -226,11 +224,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch) ...@@ -226,11 +224,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
if not _valid_deep_gemm_shape(M, N, K): if not _valid_deep_gemm_shape(M, N, K):
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
chunk_size = 1024
torch.manual_seed(seed) torch.manual_seed(seed)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
block_size = get_mk_alignment_for_contiguous_layout() block_size = get_mk_alignment_for_contiguous_layout()
dtype = torch.bfloat16 dtype = torch.bfloat16
...@@ -252,9 +247,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch) ...@@ -252,9 +247,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
# setup code in case we are able to revisit this later. # setup code in case we are able to revisit this later.
use_compile = False use_compile = False
use_cudagraph = ( use_cudagraph = N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
......
...@@ -321,7 +321,6 @@ def test_cutlass_moe_8_bit_no_graph( ...@@ -321,7 +321,6 @@ def test_cutlass_moe_8_bit_no_graph(
ep_size: int | None = None, ep_size: int | None = None,
): ):
set_random_seed(7) set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
...@@ -376,7 +375,6 @@ def test_cutlass_moe_8_bit_cuda_graph( ...@@ -376,7 +375,6 @@ def test_cutlass_moe_8_bit_cuda_graph(
workspace_init, workspace_init,
): ):
set_random_seed(7) set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
dtype = torch.half dtype = torch.half
......
...@@ -204,7 +204,6 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -204,7 +204,6 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
if not current_platform.has_device_capability(100): if not current_platform.has_device_capability(100):
pytest.skip("Test is only supported for sm >= 100") pytest.skip("Test is only supported for sm >= 100")
set_random_seed(7) set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit( td = TestData.make_moe_tensors_8bit(
m, k, n, e, is_trtllm=True, activation=activation m, k, n, e, is_trtllm=True, activation=activation
...@@ -289,7 +288,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -289,7 +288,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
workspace_init, workspace_init,
): ):
set_random_seed(7) set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit( td = TestData.make_moe_tensors_8bit(
m, k, n, e, is_trtllm=False, activation=activation m, k, n, e, is_trtllm=False, activation=activation
......
...@@ -84,12 +84,6 @@ def rank_worker( ...@@ -84,12 +84,6 @@ def rank_worker(
set_random_seed(pgi.rank) set_random_seed(pgi.rank)
# sanity check
from vllm import envs
if base_config.fused_moe_chunk_size is not None:
assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
# get weights to this device # get weights to this device
weights.to_current_device() weights.to_current_device()
...@@ -162,7 +156,6 @@ Ns = [1024] ...@@ -162,7 +156,6 @@ Ns = [1024]
TOPKs = [4, 1] TOPKs = [4, 1]
Es = [32] Es = [32]
DTYPEs = [torch.bfloat16] DTYPEs = [torch.bfloat16]
FUSED_MOE_CHUNK_SIZES = [None, 16]
def is_nyi_config(config: Config) -> bool: def is_nyi_config(config: Config) -> bool:
...@@ -185,14 +178,13 @@ def generate_valid_test_cases( ...@@ -185,14 +178,13 @@ def generate_valid_test_cases(
cases = [] cases = []
total = 0 total = 0
for k, n, e, dtype, quant_config, combination, chunk_size in product( for k, n, e, dtype, quant_config, combination in product(
Ks, Ks,
Ns, Ns,
Es, Es,
DTYPEs, DTYPEs,
MK_QUANT_CONFIGS, MK_QUANT_CONFIGS,
product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES), product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
FUSED_MOE_CHUNK_SIZES,
): ):
total = total + 1 total = total + 1
...@@ -206,7 +198,6 @@ def generate_valid_test_cases( ...@@ -206,7 +198,6 @@ def generate_valid_test_cases(
quant_config=quant_config, quant_config=quant_config,
prepare_finalize_type=combination[0], prepare_finalize_type=combination[0],
fused_experts_type=combination[1], fused_experts_type=combination[1],
fused_moe_chunk_size=chunk_size,
world_size=world_size, world_size=world_size,
) )
...@@ -234,7 +225,6 @@ def generate_valid_test_cases( ...@@ -234,7 +225,6 @@ def generate_valid_test_cases(
quant_config, quant_config,
combination[0], combination[0],
combination[1], combination[1],
chunk_size,
world_size, world_size,
) )
) )
...@@ -245,7 +235,7 @@ def generate_valid_test_cases( ...@@ -245,7 +235,7 @@ def generate_valid_test_cases(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size", "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,world_size",
generate_valid_test_cases( generate_valid_test_cases(
world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
), ),
...@@ -259,7 +249,6 @@ def test_modular_kernel_combinations_multigpu( ...@@ -259,7 +249,6 @@ def test_modular_kernel_combinations_multigpu(
quant_config: TestMoEQuantConfig | None, quant_config: TestMoEQuantConfig | None,
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
fused_experts_type: mk.FusedMoEExperts, fused_experts_type: mk.FusedMoEExperts,
chunk_size: int | None,
world_size: int, world_size: int,
pytestconfig, pytestconfig,
): ):
...@@ -280,7 +269,6 @@ def test_modular_kernel_combinations_multigpu( ...@@ -280,7 +269,6 @@ def test_modular_kernel_combinations_multigpu(
quant_config=quant_config, quant_config=quant_config,
prepare_finalize_type=prepare_finalize_type, prepare_finalize_type=prepare_finalize_type,
fused_experts_type=fused_experts_type, fused_experts_type=fused_experts_type,
fused_moe_chunk_size=chunk_size,
world_size=world_size, world_size=world_size,
) )
verbosity = pytestconfig.getoption("verbose") verbosity = pytestconfig.getoption("verbose")
...@@ -288,7 +276,7 @@ def test_modular_kernel_combinations_multigpu( ...@@ -288,7 +276,7 @@ def test_modular_kernel_combinations_multigpu(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size", "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,world_size",
generate_valid_test_cases( generate_valid_test_cases(
world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
), ),
...@@ -301,7 +289,6 @@ def test_modular_kernel_combinations_singlegpu( ...@@ -301,7 +289,6 @@ def test_modular_kernel_combinations_singlegpu(
quant_config: TestMoEQuantConfig | None, quant_config: TestMoEQuantConfig | None,
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
fused_experts_type: mk.FusedMoEExperts, fused_experts_type: mk.FusedMoEExperts,
chunk_size: int | None,
world_size: int, world_size: int,
pytestconfig, pytestconfig,
workspace_init, workspace_init,
...@@ -318,7 +305,6 @@ def test_modular_kernel_combinations_singlegpu( ...@@ -318,7 +305,6 @@ def test_modular_kernel_combinations_singlegpu(
quant_config=quant_config, quant_config=quant_config,
prepare_finalize_type=prepare_finalize_type, prepare_finalize_type=prepare_finalize_type,
fused_experts_type=fused_experts_type, fused_experts_type=fused_experts_type,
fused_moe_chunk_size=chunk_size,
world_size=world_size, world_size=world_size,
) )
......
...@@ -287,7 +287,6 @@ def run_moe_test( ...@@ -287,7 +287,6 @@ def run_moe_test(
@pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_fused_moe( def test_fused_moe(
m: int, m: int,
n: int, n: int,
...@@ -297,14 +296,11 @@ def test_fused_moe( ...@@ -297,14 +296,11 @@ def test_fused_moe(
ep_size: int, ep_size: int,
dtype: torch.dtype, dtype: torch.dtype,
padding: bool, padding: bool,
chunk_size: int,
monkeypatch, monkeypatch,
workspace_init, workspace_init,
): ):
set_random_seed(7) set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
# #
# Setup test data # Setup test data
# #
...@@ -398,12 +394,12 @@ def test_fused_moe( ...@@ -398,12 +394,12 @@ def test_fused_moe(
) )
def test_fused_moe_int64_overflow(monkeypatch, workspace_init): def test_fused_moe_int64_overflow(workspace_init):
"""Regression test for int32 overflow in stride*offset products. """Regression test for int32 overflow in stride*offset products.
When chunking is disabled and M is large, stride_cm * offs_token can With large M, stride_cm * offs_token can exceed int32 max. Verifies
exceed int32 max. Verifies the offs_token int64 cast (fix for #34413) the offs_token int64 cast (fix for #34413) prevents overflow and
prevents overflow and produces correct results. produces correct results.
Reproduces the scenario from PR #34279. Reproduces the scenario from PR #34279.
""" """
...@@ -417,9 +413,6 @@ def test_fused_moe_int64_overflow(monkeypatch, workspace_init): ...@@ -417,9 +413,6 @@ def test_fused_moe_int64_overflow(monkeypatch, workspace_init):
m, n, k, e, topk = 100000, 2048, 1024, 8, 6 m, n, k, e, topk = 100000, 2048, 1024, 8, 6
dtype = torch.bfloat16 dtype = torch.bfloat16
# Disable chunking to expose the overflow-prone code path
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "10000000")
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
...@@ -452,7 +445,6 @@ def test_fused_moe_int64_overflow(monkeypatch, workspace_init): ...@@ -452,7 +445,6 @@ def test_fused_moe_int64_overflow(monkeypatch, workspace_init):
@pytest.mark.parametrize("topk", TOP_KS_SMALL) @pytest.mark.parametrize("topk", TOP_KS_SMALL)
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_naive_block_assignment_moe( def test_naive_block_assignment_moe(
m: int, m: int,
n: int, n: int,
...@@ -461,14 +453,11 @@ def test_naive_block_assignment_moe( ...@@ -461,14 +453,11 @@ def test_naive_block_assignment_moe(
topk: int, topk: int,
dtype: torch.dtype, dtype: torch.dtype,
padding: bool, padding: bool,
chunk_size: int,
monkeypatch, monkeypatch,
workspace_init, workspace_init,
): ):
set_random_seed(7) set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
# #
# Setup test data # Setup test data
# #
......
...@@ -53,8 +53,6 @@ if TYPE_CHECKING: ...@@ -53,8 +53,6 @@ if TYPE_CHECKING:
VLLM_CPU_SGL_KERNEL: bool = False VLLM_CPU_SGL_KERNEL: bool = False
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_XLA_CHECK_RECOMPILATION: bool = False VLLM_XLA_CHECK_RECOMPILATION: bool = False
VLLM_FUSED_MOE_CHUNK_SIZE: int = 16 * 1024
VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto" VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
...@@ -822,15 +820,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -822,15 +820,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
), ),
# Enable SPMD mode for TPU backend. # Enable SPMD mode for TPU backend.
"VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))), "VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
"VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int(
os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(16 * 1024))
),
# Control whether to use fused MoE activation chunking. Current chunking
# logic is incompatible with torch.compile and causes IMA. See issue
# https://github.com/vllm-project/vllm/issues/19631.
"VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING": lambda: bool(
int(os.getenv("VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING", "1"))
),
# If set, the OpenAI API server will stay alive even after the underlying # If set, the OpenAI API server will stay alive even after the underlying
# AsyncLLMEngine errors and stops serving requests # AsyncLLMEngine errors and stops serving requests
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool( "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool(
......
...@@ -190,9 +190,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -190,9 +190,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False, use_int4_w4a16=False,
) )
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE) M = num_tokens
max_lora_rank = self.w13_lora_a_stacked[0].shape[-2] max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
shrink_config, expand_config = self._get_lora_moe_configs( shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w13", op_prefix="w13",
...@@ -281,9 +280,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -281,9 +280,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False, use_int4_w4a16=False,
) )
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE) M = num_tokens
max_lora_rank = self.w2_lora_a_stacked[0].shape[-2] max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
shrink_config, expand_config = self._get_lora_moe_configs( shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w2", op_prefix="w2",
......
...@@ -311,9 +311,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular): ...@@ -311,9 +311,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular):
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True return True
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
......
...@@ -400,9 +400,6 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): ...@@ -400,9 +400,6 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
or moe_parallel_config.use_deepep_ht_kernels or moe_parallel_config.use_deepep_ht_kernels
) )
def supports_chunking(self) -> bool:
return True
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
...@@ -445,9 +442,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): ...@@ -445,9 +442,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts return mk.FusedMoEActivationFormat.BatchedExperts
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
...@@ -713,9 +707,6 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular): ...@@ -713,9 +707,6 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
def supports_chunking(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP() return TopKWeightAndReduceNoOP()
...@@ -998,9 +989,6 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular): ...@@ -998,9 +989,6 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
"This method should not be called." "This method should not be called."
) )
def supports_chunking(self) -> bool:
return True
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
......
...@@ -154,9 +154,6 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular): ...@@ -154,9 +154,6 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
# NOTE(rob): discovered an IMA with this combination. Needs investigation. # NOTE(rob): discovered an IMA with this combination. Needs investigation.
return not moe_parallel_config.use_fi_all2allv_kernels return not moe_parallel_config.use_fi_all2allv_kernels
def supports_chunking(self) -> bool:
return True
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
......
...@@ -92,16 +92,6 @@ class FallbackExperts(mk.FusedMoEExpertsModular, ABC): ...@@ -92,16 +92,6 @@ class FallbackExperts(mk.FusedMoEExpertsModular, ABC):
moe_parallel_config moe_parallel_config
) and fallback_cls._supports_parallel_config(moe_parallel_config) ) and fallback_cls._supports_parallel_config(moe_parallel_config)
def supports_chunking(self) -> bool:
assert (
self.experts.supports_chunking()
== self.fallback_experts.supports_chunking()
)
return (
self.experts.supports_chunking()
and self.fallback_experts.supports_chunking()
)
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
assert ( assert (
self.experts.supports_expert_map() self.experts.supports_expert_map()
......
...@@ -83,12 +83,6 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): ...@@ -83,12 +83,6 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
def supports_chunking(self) -> bool:
# This refers to TP chunking; DP chunking is handled separately.
# TODO(shuw@nvidia.com): Set to False to be consistent with
# batched_deep_gemm_moe
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl. # Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate() return TopKWeightAndReduceDelegate()
......
...@@ -195,10 +195,6 @@ class FlashInferExperts(mk.FusedMoEExpertsModular): ...@@ -195,10 +195,6 @@ class FlashInferExperts(mk.FusedMoEExpertsModular):
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
def supports_chunking(self) -> bool:
# This refers to TP chunking; DP chunking is handled separately.
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP() return TopKWeightAndReduceNoOP()
......
...@@ -712,9 +712,6 @@ class NaiveBatchedExperts(mk.FusedMoEExpertsModular): ...@@ -712,9 +712,6 @@ class NaiveBatchedExperts(mk.FusedMoEExpertsModular):
"This method should not be called." "This method should not be called."
) )
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
...@@ -957,9 +954,6 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular): ...@@ -957,9 +954,6 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True return True
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment