Unverified commit a57c8228, authored by bnellnm, committed by GitHub
Browse files

[Moe Refactor] Make Inplace Flag for FusedMoEModularKernel part of the constructor (#33375)


Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 1ee95841
......@@ -620,6 +620,7 @@ def make_modular_kernel(
modular_kernel = mk.FusedMoEModularKernel(
prepare_finalize=prepare_finalize,
fused_experts=fused_experts,
inplace=False,
)
return modular_kernel
......
......@@ -74,7 +74,11 @@ def test_batched_deepgemm_vs_triton(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
mk_triton = FusedMoEModularKernel(
prep_finalize,
triton_experts,
inplace=False,
)
out_triton = mk_triton(
hidden_states=a,
......@@ -82,7 +86,6 @@ def test_batched_deepgemm_vs_triton(
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
global_num_experts=E,
)
......@@ -93,7 +96,11 @@ def test_batched_deepgemm_vs_triton(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
mk_deepgemm = FusedMoEModularKernel(
prep_finalize,
deepgemm_experts,
inplace=False,
)
out_deepgemm = mk_deepgemm(
hidden_states=a,
......@@ -101,7 +108,6 @@ def test_batched_deepgemm_vs_triton(
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
global_num_experts=E,
)
......
......@@ -9,6 +9,7 @@ from tests.kernels.moe.utils import (
make_dummy_moe_config,
make_test_quant_config,
make_test_weights,
modular_triton_fused_moe,
)
from tests.kernels.quant_utils import (
native_per_token_group_quant_fp8,
......@@ -26,9 +27,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
......@@ -261,6 +259,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
),
inplace=False,
)
def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):
......
......@@ -207,6 +207,7 @@ def run_with_expert_maps(
),
quant_config=new_quant_config,
),
inplace=False,
)
out_tensor = out_tensor + kernel(**kwargs)
......@@ -266,6 +267,7 @@ def run_8_bit(
),
quant_config=quant_config,
),
inplace=False,
)
return kernel(**kwargs)
......
......@@ -194,8 +194,11 @@ def make_ll_modular_kernel(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
return FusedMoEModularKernel(
prepare_finalize=a2a,
fused_experts=fused_experts,
inplace=False,
)
def make_ht_modular_kernel(
......@@ -224,8 +227,11 @@ def make_ht_modular_kernel(
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
return FusedMoEModularKernel(
prepare_finalize=a2a,
fused_experts=fused_experts,
inplace=False,
)
def make_modular_kernel(
......@@ -318,7 +324,6 @@ def deepep_deepgemm_moe_impl(
w2=w2,
topk_weights=test_tensors.topk_weights,
topk_ids=test_tensors.topk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
......
......@@ -179,7 +179,11 @@ def make_modular_kernel(
quant_config=quant_config,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
mk = FusedMoEModularKernel(
prepare_finalize=a2a,
fused_experts=fused_experts,
inplace=False,
)
return mk
......@@ -256,7 +260,6 @@ def deep_ep_moe_impl(
w2=w2,
topk_weights=topk_weights_chunk,
topk_ids=topk_chunk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
......
......@@ -115,6 +115,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
),
inplace=False,
)
# triton reference
......@@ -135,7 +136,6 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
)
diff = calc_diff(out_deepgemm, out_triton)
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
......
......@@ -301,6 +301,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=False,
)
flashinfer_cutlass_output = kernel(
......@@ -309,7 +310,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td.layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=e,
expert_map=None,
......
......@@ -108,6 +108,7 @@ def test_flashinfer_fp4_moe_no_graph(
flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
inplace=False,
)
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
......
......@@ -180,7 +180,11 @@ def oai_triton_moe_impl(
else:
fused_experts = OAITritonExperts(make_dummy_moe_config(), quant_config)
mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)
mk = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
fused_experts,
inplace=False,
)
return mk.forward(
hidden_states=x,
......@@ -188,7 +192,6 @@ def oai_triton_moe_impl(
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation="swigluoai",
global_num_experts=num_experts,
expert_map=None,
......
......@@ -18,7 +18,11 @@ from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe, make_dummy_moe_config
from tests.kernels.moe.utils import (
fused_moe,
make_dummy_moe_config,
modular_triton_fused_moe,
)
from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
......@@ -36,9 +40,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
batched_fused_marlin_moe,
fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_permute_bias,
)
......
......@@ -95,6 +95,7 @@ def test_cutlass_fp4_moe_no_graph(
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
),
inplace=False,
)
cutlass_output = kernel(
......
......@@ -172,6 +172,7 @@ def pplx_cutlass_moe(
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
inplace=False,
)
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
......
......@@ -592,6 +592,7 @@ def pplx_moe(
prepare_finalize,
experts,
shared_experts,
inplace=False,
)
# Note: for now use_compile will error out if the problem size is
......
......@@ -7,7 +7,11 @@ import vllm._custom_ops as ops
from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe import (
TritonExperts,
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
......@@ -20,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.deep_gemm import per_block_cast_to_fp8
from vllm.utils.math_utils import round_up
......@@ -116,6 +123,7 @@ def batched_moe(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
),
inplace=False,
)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
......@@ -157,6 +165,7 @@ def naive_batched_moe(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
),
inplace=False,
)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
......@@ -554,3 +563,16 @@ def make_shared_experts(
return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
finally:
torch.set_default_dtype(old_dtype)
def modular_triton_fused_moe(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    shared_experts: torch.nn.Module | None = None,
) -> FusedMoEModularKernel:
    """Build a ``FusedMoEModularKernel`` backed by ``TritonExperts``.

    The kernel combines a fresh ``MoEPrepareAndFinalizeNoEP`` instance with
    ``TritonExperts(moe_config, quant_config)`` and is always constructed
    with ``inplace=False``.

    Args:
        moe_config: Forwarded as the first argument of ``TritonExperts``.
        quant_config: Forwarded as the second argument of ``TritonExperts``.
        shared_experts: Optional module passed through unchanged as the
            third positional argument of ``FusedMoEModularKernel``.

    Returns:
        The newly constructed ``FusedMoEModularKernel``.
    """
    # Construct the collaborators in the same order the call site used.
    prepare_finalize = MoEPrepareAndFinalizeNoEP()
    experts = TritonExperts(moe_config, quant_config)
    return FusedMoEModularKernel(
        prepare_finalize,
        experts,
        shared_experts,
        inplace=False,
    )
......@@ -1083,13 +1083,16 @@ class FusedMoEConfig:
router_logits_dtype: torch.dtype | None = None
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
has_bias: bool = False
is_act_and_mul: bool = True
is_lora_enabled: bool = False
# This flag disables the inplace optimization in MoE kernels.
# If this flag is True, the kernel must not use inplace.
# If this flag is False, the kernel is free to use inplace
# or not.
disable_inplace: bool = True
def __post_init__(self):
if self.dp_size > 1:
logger.debug_once(
......
......@@ -1165,6 +1165,7 @@ def cutlass_moe_w4a8_fp8(
quant_config=quant_config,
group_size=group_size,
),
inplace=False,
)
return fn(
......
......@@ -267,6 +267,7 @@ def fused_marlin_moe(
if inplace:
assert output is None, "Conflicting request"
assert not disable_inplace()
quant_type = ScalarType.from_id(quant_type_id)
assert quant_type in [
......@@ -356,10 +357,7 @@ def fused_marlin_moe(
).view(-1, topk, K)
if output is None:
if inplace and not disable_inplace():
output = hidden_states
else:
output = torch.empty_like(hidden_states)
output = hidden_states if inplace else torch.empty_like(hidden_states)
if moe_sum is None:
return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
......
......@@ -27,9 +27,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
......@@ -1511,7 +1508,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
if inplace and not disable_inplace():
if inplace:
return torch_vllm_inplace_fused_experts
return torch_vllm_outplace_fused_experts
......@@ -1534,6 +1531,8 @@ def fused_experts(
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
assert not inplace or not disable_inplace()
return dispatch_fused_experts_func(inplace)(
hidden_states=hidden_states,
w1=w1,
......@@ -1593,7 +1592,7 @@ def fused_experts_impl(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
inplace: bool,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
......@@ -1712,10 +1711,7 @@ def fused_experts_impl(
else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
if inplace and not disable_inplace():
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
if ocp_mx_scheme is not None:
# TODO: On platforms for which `current_platform.supports_mx()` is True
......@@ -2291,15 +2287,3 @@ class TritonWNA16Experts(TritonExperts):
# separate function is required for MoE + LoRA
self.moe_sum(intermediate_cache3, output)
def modular_triton_fused_moe(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel:
    """Build an ``mk.FusedMoEModularKernel`` backed by ``TritonExperts``.

    The kernel combines a fresh ``MoEPrepareAndFinalizeNoEP`` instance with
    ``TritonExperts(moe_config, quant_config)``.

    Args:
        moe_config: Forwarded as the first argument of ``TritonExperts``.
        quant_config: Forwarded as the second argument of ``TritonExperts``.
        shared_experts: Optional module passed through unchanged as the
            third positional argument of ``mk.FusedMoEModularKernel``.

    Returns:
        The newly constructed ``mk.FusedMoEModularKernel``.
    """
    # Construct the collaborators in the same order the call site used.
    prepare_finalize = MoEPrepareAndFinalizeNoEP()
    experts = TritonExperts(moe_config, quant_config)
    return mk.FusedMoEModularKernel(
        prepare_finalize,
        experts,
        shared_experts,
    )
......@@ -113,10 +113,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def supports_eplb(self) -> bool:
return False
@property
def allow_inplace(self) -> bool:
return False
@property
def method_name(self) -> str:
return self.__class__.__name__
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment