Unverified commit 5963b98b, authored by bnellnm, committed by GitHub
Browse files

[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent e6585ddb
......@@ -9,6 +9,8 @@ import torch
from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
......@@ -143,10 +145,16 @@ def pplx_cutlass_moe(
device="cuda",
dtype=torch.int64)
experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
out_dtype, per_act_token, per_out_ch,
ab_strides1, ab_strides2, c_strides1,
c_strides2)
experts = CutlassBatchedExpertsFp8(
num_local_experts, num_dispatchers, out_dtype, ab_strides1,
ab_strides2, c_strides1, c_strides2,
fp8_w8a8_moe_quant_config(
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank]))
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,
......@@ -167,10 +175,7 @@ def pplx_cutlass_moe(
chunk_topk_ids,
global_num_experts=num_experts,
expert_map=None, #TODO
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank])
)
torch.cuda.synchronize()
......
......@@ -58,7 +58,7 @@ BATCHED_MOE_MNK_FACTORS = [
]
PPLX_COMBOS = [
# TODO: figure out why this fails, seems to be test problem
# TODO(bnell): figure out why this fails, seems to be test problem
#(1, 128, 128),
(2, 128, 512),
(3, 1024, 2048),
......@@ -360,18 +360,18 @@ def pplx_prepare_finalize(
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
a_chunk,
a1_scale,
a2_scale,
chunk_topk_weight,
chunk_topk_ids,
num_experts,
None,
False,
FusedMoEQuantConfig(
FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant,
False,
block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False,
block_shape=block_shape,
a1_scale=a1_scale,
a2_scale=a2_scale,
),
)
......@@ -540,20 +540,6 @@ def pplx_moe(
topk_ids = topk_ids.to(dtype=torch.uint32)
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk = chunk_by_rank(a, rank, world_size)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
......@@ -567,6 +553,28 @@ def pplx_moe(
a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
)
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=quant_config,
)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
......@@ -585,10 +593,6 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts)
if use_cudagraphs:
......@@ -605,10 +609,6 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts)
torch.cuda.synchronize()
......@@ -820,7 +820,7 @@ def test_pplx_moe_slow(
k,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
......@@ -897,7 +897,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
k,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
args["w1"] = w1
args["w2"] = w2
......
......@@ -7,10 +7,12 @@ import itertools
import pytest
import torch
from tests.kernels.moe.utils import fused_moe
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (9, 0):
......@@ -152,11 +154,12 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
score,
topk,
renormalize=False,
use_fp8_w8a8=True, # using fp8
per_channel_quant=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=None, # Not using block quantization
quant_config=fp8_w8a8_moe_quant_config(
per_act_token_quant=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=None, # Not using block quantization
),
)
# Check results
......
......@@ -9,7 +9,8 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
......@@ -34,18 +35,22 @@ def triton_moe(
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_channel_quant=per_act_token_quant,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
block_shape=block_shape)
quant_config=quant_config)
def batched_moe(
......@@ -64,6 +69,16 @@ def batched_moe(
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
......@@ -72,21 +87,11 @@ def batched_moe(
BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
quant_config=quant_config,
),
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
def naive_batched_moe(
......@@ -105,6 +110,16 @@ def naive_batched_moe(
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
......@@ -113,21 +128,11 @@ def naive_batched_moe(
NaiveBatchedExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
quant_config=quant_config,
),
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
def chunk_scales(scales: Optional[torch.Tensor], start: int,
......@@ -216,7 +221,7 @@ def make_test_weight(
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
......@@ -228,7 +233,7 @@ def make_test_weight(
w_gs_l = [None] * e
for idx in range(e):
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape)
w = torch.stack(w_l)
w_s = torch.stack(w_s_l)
......@@ -258,16 +263,16 @@ def make_test_weights(
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]]:
return (
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
per_out_ch_quant),
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
per_out_ch_quant),
)
......@@ -285,6 +290,76 @@ def per_token_cast_to_fp8(
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def make_test_quant_config(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: Union[torch.dtype, str, None] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
    """
    Create test weights plus a FusedMoEQuantConfig that describes them.

    The weights come from make_test_weights, called with per_out_ch_quant
    set to the value of per_act_token_quant. Returns (w1, w2, quant_config),
    where w1/w2 are the second element of each tuple produced by
    make_test_weights (presumably the quantized weights; confirm against
    make_test_weight).
    """
    w1_parts, w2_parts = make_test_weights(
        e,
        n,
        k,
        in_dtype,
        quant_dtype,
        per_out_ch_quant=per_act_token_quant,
        block_shape=block_shape,
    )
    _, w1, w1_s, w1_gs = w1_parts
    _, w2, w2_s, w2_gs = w2_parts

    # Hacky/trivial scales for nvfp4: all-ones global scales, one entry per
    # expert. Every other quant type runs without activation scales.
    a1_gscale: Optional[torch.Tensor] = None
    a2_gscale: Optional[torch.Tensor] = None
    if quant_dtype == "nvfp4":
        a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
        a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)

    # The activation scales alias the global scales (None unless nvfp4).
    a1_scale = a1_gscale
    a2_scale = a2_gscale

    # TODO: make sure this is handled properly
    # The alphas are the reciprocals of the weight global scales, when the
    # weights have them.
    g1_alphas = None if w1_gs is None else (1 / w1_gs)
    g2_alphas = None if w2_gs is None else (1 / w2_gs)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
    )
    return w1, w2, quant_config
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    renormalize: bool = False,
    quant_config: Optional[FusedMoEQuantConfig] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Test helper that runs routing followed by the fused experts.

    The scores are cast to float and handed to fused_topk; the resulting
    top-k weights and ids are then passed to fused_experts together with
    the quantization config, expert count and expert map.
    """
    router_scores = score.float()
    # fused_topk returns a third value that this helper does not use.
    weights, ids, _ = fused_topk(hidden_states, router_scores, topk,
                                 renormalize)
    result = fused_experts(
        hidden_states,
        w1,
        w2,
        weights,
        ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        quant_config=quant_config,
    )
    return result
# CustomOp?
class BaselineMM(torch.nn.Module):
......
......@@ -8,7 +8,8 @@ import pytest
import torch
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8)
from vllm.platforms import current_platform
......@@ -42,7 +43,8 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
return C.reshape(origin_C_shape).to(output_dtype)
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight,
topk_ids):
"""This function performs fused moe with per-column int8 quantization
using native torch."""
......@@ -57,8 +59,6 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
# Calculate routing
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
# Process each expert
......@@ -127,20 +127,27 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
score = torch.randn((M, E), dtype=dtype)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(score, topk)
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk,
topk_weights, topk_ids)
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
out = fused_moe(
quant_config = FusedMoEQuantConfig.make(
torch.int8,
per_act_token_quant=True,
block_shape=None,
w1_scale=w1_s,
w2_scale=w2_s,
)
out = fused_experts(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_int8_w8a8=True, # Using int8-w8a8
per_channel_quant=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=None, # Not using block quantization
topk_weights,
topk_ids,
quant_config=quant_config,
)
# Check results
......
......@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None
......@@ -36,6 +37,7 @@ __all__ = [
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize",
"activation_without_mul",
"override_config",
"get_config",
]
......@@ -43,7 +45,6 @@ __all__ = [
if HAS_TRITON:
# import to register the custom ops
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
......@@ -56,13 +57,12 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts, fused_experts, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
TritonExperts, fused_experts, fused_topk, get_config_file_name,
grouped_topk)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
__all__ += [
"fused_moe",
"fused_topk",
"fused_experts",
"get_config_file_name",
......
......@@ -8,6 +8,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
......@@ -212,27 +214,20 @@ def silu_mul_fp8_quant_deep_gemm_cuda(
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# The Deep Gemm kernels only support block size of 128
DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
def __init__(self,
max_num_tokens: int,
num_dispatchers: int,
block_shape: list[int],
per_act_token_quant=False):
def __init__(
self,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
):
"""
max_num_tokens: Maximum number of tokens from a DP Rank
num_dispatchers: The number of DP dispatchers.
block_shape: Block quantization block shape.
per_act_token_quant: Per activation token quantization flag.
quant_config: Quantization configuration
"""
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE
super().__init__(quant_config)
assert self.block_shape == deep_gemm_block_shape()
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
......@@ -290,12 +285,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -321,11 +311,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m = max_num_tokens
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, self.w1_scale),
workspace1, expert_num_tokens, expected_m)
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda(
workspace1, expert_num_tokens)
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
expert_num_tokens, expected_m)
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, self.w2_scale),
output, expert_num_tokens, expected_m)
......@@ -8,55 +8,37 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self,
max_num_tokens: int,
num_dispatchers: int,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
allow_deep_gemm: bool = False):
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
))
def __init__(
self,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
allow_deep_gemm: bool = False,
):
super().__init__(quant_config)
self.batched_triton_experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape,
quant_config=self.quant_config,
)
self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8
and self.block_shape
== BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE)
self.allow_deep_gemm = (allow_deep_gemm
and self.quant_config.use_fp8_w8a8 and
self.block_shape == deep_gemm_block_shape())
self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
block_shape=self.block_shape, # type: ignore[arg-type]
quant_config=self.quant_config,
) if self.allow_deep_gemm else None
assert (self.batched_deep_gemm_experts is not None
......@@ -143,12 +125,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -158,7 +135,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if self.allow_deep_gemm else self.batched_triton_experts)
assert experts is not None
experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
activation, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_tokens_meta,
activation, global_num_experts, expert_map, a1q_scale,
workspace13, workspace2, expert_tokens_meta,
apply_router_weight_on_input)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
import torch
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import cdiv
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.utils import cdiv, has_triton_kernels
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if TYPE_CHECKING and has_triton_kernels:
from triton_kernels.matmul_ogs import PrecisionConfig
logger = init_logger(__name__)
def _get_quant_config_quantization_args(
quant_config: Optional[QuantizationConfig],
prop_name: str,
) -> Optional[QuantizationArgs]:
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
and "Linear" in quant_config.target_scheme_map and
"input_activations" in quant_config.target_scheme_map["Linear"]):
return quant_config.target_scheme_map["Linear"].get(prop_name)
def _get_config_dtype_str(
    dtype: torch.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    use_mxfp4_w4a4: bool = False,
) -> Optional[str]:
    """
    Map a quantization scheme to the tag used in tuning-config filenames.

    The flags are examined in a fixed order and the first one that is set
    determines the result. If no flag is set, torch.float maps to
    "float32" and every other dtype maps to None. See
    try_get_optimal_moe_config in fused_moe.py.
    """
    # (flag, tag) pairs in priority order.
    candidates = (
        (use_fp8_w8a8, "fp8_w8a8"),
        (use_int8_w8a16, "int8_w8a16"),
        (use_int4_w4a16, "int4_w4a16"),
        (use_mxfp4_w4a4, "mxfp4_w4a4"),
    )
    for enabled, tag in candidates:
        if enabled:
            return tag

    # avoiding cases where kernel fails when float32 MoE
    # use fp16/bfloat16 configs
    if dtype == torch.float:
        return "float32"
    return None
def _quant_flags_to_group_shape(
quant_dtype: Union[torch.dtype, str, None],
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[Optional[GroupShape], Optional[GroupShape]]:
"""
Convert MoE quantization flags into more generic GroupShapes.
"""
a_shape: Optional[GroupShape]
w_shape: Optional[GroupShape]
if block_shape is not None:
assert not per_act_token_quant
assert not per_out_ch_quant
# TODO(bnell): this is not quite right for activations since first
# dim should be 1.
a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
else:
return None
w_shape = None
a_shape = None if quant_dtype is None else GroupShape.PER_TENSOR
if per_act_token_quant:
a_shape = GroupShape.PER_TOKEN
def get_quant_config_input_quant(
quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
return _get_quant_config_quantization_args(quant_config,
"input_activations")
if per_out_ch_quant:
w_shape = GroupShape.PER_TOKEN
return a_shape, w_shape
def get_quant_config_weight_quant(
quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
return _get_quant_config_quantization_args(quant_config, "weights")
@dataclass
class FusedMoEQuantDesc:
"""
A quantization descriptor for fused MoE ops. This class can describe
either activations or weights.
"""
# The quantized type of this parameters. None means unquantized or
# already quantized.
# TODO (bnell): use scalar_type instead of Union.
dtype: Union[torch.dtype, str, None] = None
def get_config_quant_dtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
use_mxfp4_w4a4: bool,
) -> Union[None, torch.dtype, str]:
if use_fp8_w8a8:
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
elif use_mxfp4_w4a4:
return "mxfp4"
return None
# A field that describes the quantization group shape, from quant_utils.py.
# * (-1, -1) for per-tensor quantization
# * (1, -1) for per-row quantization
# * (-1, 1) for per-column quantization
# * (128, 128) for 128x128 deepseek style block quantization
# * (1, 128) for deepseek style activation quantization
# (i.e. per-token-per-group)
shape: Optional[GroupShape] = None
# Quantization scales.
# TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc?
scale: Union[torch.Tensor, "PrecisionConfig", None] = None
# Quantization alphas or gscales, used for nvfp4 types.
# TODO(bnell): put some of these in subclasses
alpha_or_gscale: Optional[torch.Tensor] = None
# Zero points for int4/int8 types
zp: Optional[torch.Tensor] = None
# Biases for GPT triton MoE
bias: Optional[torch.Tensor] = None
# TODO(bnell): have subclasses for specific moe methods?
# e.g. for specific arguments bias, precision, etc.
@dataclass
class FusedMoEQuantConfig:
# The post quantization activation type.
# TODO (bnell): use scalar_type instead of Union.
quant_dtype: Union[torch.dtype, str, None] = None
per_act_token_quant: bool = False
per_out_ch_quant: bool = False
block_shape: Optional[list[int]] = None
# TODO: add col major flag?
# add detailed quant info for input, intermediates, weights, etc?
"""
The FusedMoEQuantConfig contains all the quantization parameters for
a single FusedMoEMethodBase operation. It consists of four
FusedMoEQuantDescs, one for each activation and set of weights.
Each FusedMoEMethodBase must implement a get_fused_moe_quant_config
method to construct a FusedMoEQuantConfig for use with that class.
FusedMoEQuant configs are only used for modular kernels, fused_experts
(from fused_moe.py), cutlass_moe_fp[48], rocm_aiter_fused_experts and
triton_kernel_moe_forward. Other MoE methods can ignore the
FusedMoEQuantConfig (for now) and hardcode it to None.
There are currently some restrictions on what can be expressed:
- Most MoE ops only support similar quantization strategies for
each parameter, e.g. both weights must have the same GroupShape
and both activations must share the same GroupShape. One exception to
this is the cutlass moe which allows per channel quantization on the
outputs. Note: these restrictions are not always rigorously checked.
- Not all fused MoE functions support all the parameters, e.g. zero points,
global scales, alphas and biases are not universally supported.
- Fully general GroupShapes are not allowed. Activations only support
per token, per tensor or K-blocked.
- Weights are not required to have a GroupShape since they have already
been quantized.
Other notes:
- PrecisionConfigs are specific to GPT OSS Triton.
- As a follow up it would probably make sense to subclass FusedMoEQuantDesc
or FusedMoEQuantConfig for particular FusedMoEMethodBase subclasses
so that only the required quantization parameters are used/stored.
"""
# TODO(bnell) make sure a1_scales/a2_scales don't interfere with chunking
_a1: FusedMoEQuantDesc
_a2: FusedMoEQuantDesc
_w1: FusedMoEQuantDesc
_w2: FusedMoEQuantDesc
def __post_init__(self):
assert (not self.per_act_token_quant
or self.block_shape is None), "illegal quantization"
#
# Convenience accessors for various properties.
#
@property
def quant_dtype(self) -> Union[torch.dtype, str, None]:
return self._a1.dtype
@property
def is_quantized(self) -> bool:
return self.quant_dtype is not None
@property
def is_per_act_token(self) -> bool:
return self.per_act_token_quant
return self._a1.shape == GroupShape.PER_TOKEN
@property
def per_act_token_quant(self) -> bool:
return self._a1.shape == GroupShape.PER_TOKEN
@property
def per_out_ch_quant(self) -> bool:
return self._w1.shape == GroupShape.PER_TOKEN
@property
def is_per_tensor(self) -> bool:
return self._a1.shape == GroupShape.PER_TENSOR
@property
def block_shape(self) -> Optional[list[int]]:
if (self._a1.shape is not None
and self._a1.shape != GroupShape.PER_TENSOR
and self._a1.shape != GroupShape.PER_TOKEN):
return [self._a1.shape.row, self._a1.shape.col]
else:
return None
@property
def is_block_quantized(self) -> bool:
return self.block_shape is not None
@property
def is_per_tensor(self) -> bool:
return not self.per_act_token_quant and self.block_shape is None
def a1_scale(self) -> Optional[torch.Tensor]:
assert self._a1.scale is None or isinstance(self._a1.scale,
torch.Tensor)
return self._a1.scale
@property
def a1_gscale(self) -> Optional[torch.Tensor]:
return self._a1.alpha_or_gscale
@property
def a2_scale(self) -> Optional[torch.Tensor]:
assert self._a2.scale is None or isinstance(self._a2.scale,
torch.Tensor)
return self._a2.scale
@property
def a2_gscale(self) -> Optional[torch.Tensor]:
return self._a2.alpha_or_gscale
@property
def w1_scale(self) -> Optional[torch.Tensor]:
assert self._w1.scale is None or isinstance(self._w1.scale,
torch.Tensor)
return self._w1.scale
@property
def w1_zp(self) -> Optional[torch.Tensor]:
return self._w1.zp
@property
def w1_bias(self) -> Optional[torch.Tensor]:
return self._w1.bias
@property
def w1_precision(self) -> Optional["PrecisionConfig"]:
    """
    Return the scale stored on the w1 descriptor as a PrecisionConfig.

    Returns None when no scale is stored. Otherwise asserts that the
    stored scale is a PrecisionConfig before returning it.
    """
    scale = self._w1.scale
    if scale is not None:
        # PrecisionConfig is imported at module scope only under
        # TYPE_CHECKING, so the name does not exist at runtime. Import it
        # lazily here so the isinstance check cannot raise NameError.
        from triton_kernels.matmul_ogs import PrecisionConfig
        assert isinstance(scale, PrecisionConfig)
    return scale
@property
def g1_alphas(self) -> Optional[torch.Tensor]:
return self._w1.alpha_or_gscale
@property
def w2_scale(self) -> Optional[torch.Tensor]:
assert self._w2.scale is None or isinstance(self._w2.scale,
torch.Tensor)
return self._w2.scale
@property
def w2_zp(self) -> Optional[torch.Tensor]:
return self._w2.zp
@property
def w2_bias(self) -> Optional[torch.Tensor]:
return self._w2.bias
@property
def w2_precision(self) -> Optional["PrecisionConfig"]:
    """
    Return the scale stored on the w2 descriptor as a PrecisionConfig.

    Returns None when no scale is stored. Otherwise asserts that the
    stored scale is a PrecisionConfig before returning it.
    """
    scale = self._w2.scale
    if scale is not None:
        # PrecisionConfig is imported at module scope only under
        # TYPE_CHECKING, so the name does not exist at runtime. Import it
        # lazily here so the isinstance check cannot raise NameError.
        from triton_kernels.matmul_ogs import PrecisionConfig
        assert isinstance(scale, PrecisionConfig)
    return scale
@property
def g2_alphas(self) -> Optional[torch.Tensor]:
return self._w2.alpha_or_gscale
@property
def use_fp8_w8a8(self) -> bool:
return self.quant_dtype == torch.float8_e4m3fn
@property
def use_int8_w8a8(self) -> bool:
return self.quant_dtype == torch.int8
@property
def use_int8_w8a16(self) -> bool:
return (self._a1.dtype is None and self._w1.dtype == torch.int8)
@property
def use_int4_w4a16(self) -> bool:
return (self._a1.dtype is None and self._w1.dtype == "int4")
@property
def use_mxfp4_w4a4(self) -> bool:
return self.quant_dtype == "mxfp4"
@property
def use_nvfp4_w4a4(self) -> bool:
return self.quant_dtype == "nvfp4"
def config_name(self, dtype: torch.dtype) -> Optional[str]:
    """
    Produce the string used to build the name of the file holding the
    tuning info for this quantization scheme. See
    try_get_optimal_moe_config in fused_moe.py.

    The result comes from _get_config_dtype_str, called with this config's
    fp8_w8a8, int8_w8a16, int4_w4a16 and mxfp4_w4a4 flags plus `dtype`.
    """
    # Read the flags in the same order the helper receives them.
    fp8_w8a8 = self.use_fp8_w8a8
    int8_w8a16 = self.use_int8_w8a16
    int4_w4a16 = self.use_int4_w4a16
    mxfp4_w4a4 = self.use_mxfp4_w4a4
    return _get_config_dtype_str(
        use_fp8_w8a8=fp8_w8a8,
        use_int8_w8a16=int8_w8a16,
        use_int4_w4a16=int4_w4a16,
        use_mxfp4_w4a4=mxfp4_w4a4,
        dtype=dtype,
    )
def scale_shape(
self,
max_tokens: int,
hidden_dim: int,
) -> Optional[tuple[int, int]]:
"""
Construct the proper activation scale shape for this
config.
"""
if self.is_quantized:
if self.is_block_quantized:
assert self.block_shape is not None
......@@ -117,6 +336,10 @@ class FusedMoEQuantConfig:
max_tokens: int,
hidden_dim: int,
) -> Optional[tuple[int, int, int]]:
"""
Construct the proper activation batched scale shape for this
config, e.g. (num experts, *scale_shape).
"""
if self.is_quantized:
scale_shape = self.scale_shape(max_tokens, hidden_dim)
assert scale_shape is not None
......@@ -126,38 +349,218 @@ class FusedMoEQuantConfig:
@staticmethod
def make(
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
quant_dtype: Union[torch.dtype, str, None] = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
block_shape: Optional[list[int]] = None,
w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
g1_alphas: Optional[torch.Tensor] = None,
g2_alphas: Optional[torch.Tensor] = None,
a1_gscale: Optional[torch.Tensor] = None,
a2_gscale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
) -> "FusedMoEQuantConfig":
assert sum([
int(flag) for flag in [
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
use_mxfp4_w4a4,
]
]) <= 1, "Quantization flags are mutually exclusive."
quant_dtype = get_config_quant_dtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
)
return FusedMoEQuantConfig(
quant_dtype,
per_act_token_quant,
per_out_ch_quant,
block_shape,
"""
General builder function for a FusedMoEQuantConfig.
- quant_dtype: Optional quantization type. None if activations are
unquantized or quantized prior to calling. Note: "nvfp4" and
"mxfp4" are the only valid string values for quant_dtype.
- per_act_token_quant: Activations have per token quantization.
- per_out_ch_quant: Outputs have per channel quantization. (only
for cutlass).
- block_shape: Optional block size for block-wise quantization.
Incompatible with per_act_token and per_out_ch quant.
- w1_scale: Optional scale to be used for w1.
- w2_scale: Optional scale to be used for w2.
- a1_scale: Optional scale to be used for a1.
- a2_scale: Optional scale to be used for a2.
- g1_alphas: Optional global quantization scales for w1 (for nvfp4).
- g2_alphas: Optional global quantization scales for w2 (for nvfp4).
- a1_gscale: Optional global quantization scales for a1 (for nvfp4).
- a2_gscale: Optional global quantization scales for a2 (for nvfp4).
- w1_bias: Optional biases for w1 (GPT OSS Triton).
- w2_bias: Optional biases for w2 (GPT OSS Triton).
- w1_zp: Optional w1 zero points for int4/int8 quantization.
- w2_zp: Optional w2 zero points for int4/int8 quantization.
"""
assert (not isinstance(quant_dtype, str) or quant_dtype == "nvfp4"
or quant_dtype == "mxfp4")
a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype,
per_act_token_quant,
per_out_ch_quant,
block_shape)
quant_config = FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale),
_a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale),
_w1=FusedMoEQuantDesc(quant_dtype, w_shape, w1_scale, g1_alphas,
w1_zp, w1_bias),
_w2=FusedMoEQuantDesc(quant_dtype, w_shape, w2_scale, g2_alphas,
w2_zp, w2_bias),
)
assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant
assert quant_config.block_shape == block_shape
return quant_config
def fp8_w8a8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    per_act_token_quant: bool = False,
    per_out_ch_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> FusedMoEQuantConfig:
    """
    Build a quant config for fp8 activations and fp8 weights.

    torch.float8_e4m3fn is used as the quant dtype; the weight scales,
    optional activation scales, quantization flags and optional block
    shape are forwarded unchanged to FusedMoEQuantConfig.make.
    """
    fp8_dtype = torch.float8_e4m3fn
    return FusedMoEQuantConfig.make(
        fp8_dtype,
        per_act_token_quant=per_act_token_quant,
        per_out_ch_quant=per_out_ch_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
def int8_w8a8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    per_act_token_quant: bool = False,
) -> FusedMoEQuantConfig:
    """
    Build a quant config for int8 activations and int8 weights.

    torch.int8 is used as the quant dtype. Per-output-channel quantization
    is always off and no block shape is set.
    """
    int8_dtype = torch.int8
    return FusedMoEQuantConfig.make(
        int8_dtype,
        per_act_token_quant=per_act_token_quant,
        per_out_ch_quant=False,
        block_shape=None,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
def mxfp4_w4a4_moe_quant_config(
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
) -> FusedMoEQuantConfig:
    """
    Build a quant config for mxfp4 activations and mxfp4 weights.

    "mxfp4" is used as the quant dtype. Per-token and per-output-channel
    quantization are always off; scales, biases and the optional block
    shape are forwarded unchanged to FusedMoEQuantConfig.make.
    """
    return FusedMoEQuantConfig.make(
        "mxfp4",
        per_act_token_quant=False,
        per_out_ch_quant=False,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
def nvfp4_moe_quant_config(
    g1_alphas: torch.Tensor,
    g2_alphas: torch.Tensor,
    a1_gscale: torch.Tensor,
    a2_gscale: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """
    Build a quant config for nvfp4 activations and nvfp4 weights.

    "nvfp4" is used as the quant dtype. Per-token and per-output-channel
    quantization are always off and no block shape is set; the weight
    scales and the global scales (g*_alphas, a*_gscale) are forwarded
    unchanged to FusedMoEQuantConfig.make.
    """
    return FusedMoEQuantConfig.make(
        "nvfp4",
        per_act_token_quant=False,
        per_out_ch_quant=False,
        block_shape=None,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
    )
def int4_w4a16_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_zp: Optional[torch.Tensor],
    w2_zp: Optional[torch.Tensor],
    block_shape: Optional[list[int]] = None,
) -> FusedMoEQuantConfig:
    """
    Build a quant config for 16-bit float activations and int4 weights.
    Note: Activations are pre-quantized.

    The activation descriptors carry only the group shape; the weight
    descriptors use the "int4" dtype with the given scales and zero points.
    """
    if block_shape is None:
        group_shape = None
    else:
        group_shape = GroupShape(*block_shape)

    # Descriptors are created in the same order as the constructor receives
    # them: a1, a2, w1, w2.
    act1_desc = FusedMoEQuantDesc(shape=group_shape)
    act2_desc = FusedMoEQuantDesc(shape=group_shape)
    w1_desc = FusedMoEQuantDesc("int4", group_shape, w1_scale, None, w1_zp)
    w2_desc = FusedMoEQuantDesc("int4", group_shape, w2_scale, None, w2_zp)
    return FusedMoEQuantConfig(
        _a1=act1_desc,
        _a2=act2_desc,
        _w1=w1_desc,
        _w2=w2_desc,
    )
def int8_w8a16_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_zp: Optional[torch.Tensor],
    w2_zp: Optional[torch.Tensor],
    block_shape: Optional[list[int]] = None,
) -> FusedMoEQuantConfig:
    """
    Build a quant config for 16-bit float activations and int8 weights.
    Note: Activations are pre-quantized.

    The activation descriptors carry only the group shape; the weight
    descriptors use torch.int8 with the given scales and zero points.
    """
    if block_shape is None:
        group_shape = None
    else:
        group_shape = GroupShape(*block_shape)

    # Descriptors are created in the same order as the constructor receives
    # them: a1, a2, w1, w2.
    act1_desc = FusedMoEQuantDesc(shape=group_shape)
    act2_desc = FusedMoEQuantDesc(shape=group_shape)
    w1_desc = FusedMoEQuantDesc(torch.int8, group_shape, w1_scale, None, w1_zp)
    w2_desc = FusedMoEQuantDesc(torch.int8, group_shape, w2_scale, None, w2_zp)
    return FusedMoEQuantConfig(
        _a1=act1_desc,
        _a2=act2_desc,
        _w1=w1_desc,
        _w2=w2_desc,
    )
def biased_moe_quant_config(
    w1_bias: Optional[torch.Tensor],
    w2_bias: Optional[torch.Tensor],
) -> FusedMoEQuantConfig:
    """
    Build a quant config for unquantized activations with biases.

    Both activation descriptors are default-constructed; each weight
    descriptor carries only its bias.
    """
    # Descriptors are created in the same order as the constructor receives
    # them: a1, a2, w1, w2.
    act1_desc = FusedMoEQuantDesc()
    act2_desc = FusedMoEQuantDesc()
    w1_desc = FusedMoEQuantDesc(bias=w1_bias)
    w2_desc = FusedMoEQuantDesc(bias=w2_bias)
    return FusedMoEQuantConfig(
        _a1=act1_desc,
        _a2=act2_desc,
        _w1=w1_desc,
        _w2=w2_desc,
    )
# A FusedMoEQuantConfig constant for an unquantized MoE op.
# Created once at import time by calling FusedMoEQuantConfig.make() with no
# arguments, so every builder parameter keeps its default value.
FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make()
@dataclass
......@@ -315,8 +718,6 @@ class FusedMoEConfig:
# The activation type.
in_dtype: torch.dtype
quant_config: Optional[FusedMoEQuantConfig] = None
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
has_bias: bool = False
......@@ -328,34 +729,6 @@ class FusedMoEConfig:
assert self.max_num_tokens > 0
@property
def quant_dtype(self) -> Union[torch.dtype, str, None]:
    """Return the quant config's `quant_dtype`, or None without a quant config."""
    cfg = self.quant_config
    return None if cfg is None else cfg.quant_dtype
@property
def block_shape(self) -> Optional[list[int]]:
    """Return the quant config's `block_shape`, or None without a quant config."""
    cfg = self.quant_config
    return None if cfg is None else cfg.block_shape
@property
def per_act_token_quant(self) -> bool:
    """Return the quant config's `per_act_token_quant`, or False without one."""
    cfg = self.quant_config
    return False if cfg is None else cfg.per_act_token_quant
@property
def per_out_ch_quant(self) -> bool:
    """Return the quant config's `per_out_ch_quant`, or False without one."""
    cfg = self.quant_config
    return False if cfg is None else cfg.per_out_ch_quant
@property
def tp_size(self):
    """Return `tp_size` from the MoE parallel config."""
    parallel_cfg = self.moe_parallel_config
    return parallel_cfg.tp_size
......@@ -401,97 +774,6 @@ class FusedMoEConfig:
"""
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
"""
return (self.quant_config is not None
and self.quant_config.quant_dtype == "nvfp4"
and envs.VLLM_USE_FLASHINFER_MOE_FP4
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
@staticmethod
def make(
num_experts: int,
experts_per_token: int,
hidden_dim: int,
num_local_experts: int,
moe_parallel_config: FusedMoEParallelConfig,
in_dtype: torch.dtype,
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config: Optional[Union[FusedMoEQuantConfig,
QuantizationConfig]] = None,
has_bias: bool = False,
) -> "FusedMoEConfig":
"""
Build a FusedMoEConfig from the given sizes, parallel config, input
dtype and an optional quantization description.

quant_config may be None, a FusedMoEQuantConfig or a QuantizationConfig.
For a QuantizationConfig, a FusedMoEQuantConfig is derived from it
(quant dtype, per-token / per-channel flags, block shape). Any other
value is stored as given.
NOTE(review): indentation is not visible in this view, so the branch
nesting described in the comments below is assumed — confirm.
"""
_quant_config: Optional[FusedMoEQuantConfig] = None
if quant_config is not None and isinstance(quant_config,
QuantizationConfig):
# Take the block shape from the config when it has the attribute.
if hasattr(quant_config, 'weight_block_size'):
block_shape = quant_config.weight_block_size
else:
block_shape = None
# Defaults, refined below from the input / weight quantization settings.
per_act_token_quant = False
per_out_ch_quant = False
quant_dtype: Union[torch.dtype, str, None] = None
input_quant = get_quant_config_input_quant(quant_config)
weight_quant = get_quant_config_weight_quant(quant_config)
if input_quant is not None:
# NOTE(review): input_quant was already checked for None just above, so
# the trailing "if input_quant is not None else False" looks redundant,
# assuming this assignment sits under that check — confirm.
per_act_token_quant = (input_quant.strategy
== QuantizationStrategy.TOKEN
if input_quant is not None else False)
# 8-bit inputs select float8_e4m3fn (FLOAT) or int8 (INT).
if input_quant.num_bits == 8:
if input_quant.type == QuantizationType.FLOAT:
quant_dtype = torch.float8_e4m3fn
elif input_quant.type == QuantizationType.INT:
quant_dtype = torch.int8
# Fallbacks keyed on the config class; each applies only while quant_dtype
# is still None. The imports are function-local — presumably to avoid
# import cycles; TODO confirm.
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
if quant_dtype is None and isinstance(quant_config, Fp8Config):
quant_dtype = torch.float8_e4m3fn
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Config)
if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
quant_dtype = "mxfp8"
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config,
ModelOptNvFp4Config):
quant_dtype = "nvfp4"
# Per-output-channel quantization follows the weight strategy.
if weight_quant is not None:
per_out_ch_quant = (
weight_quant.strategy == QuantizationStrategy.CHANNEL)
# A populated quant config is built only when a dtype was determined;
# otherwise a default FusedMoEQuantConfig is used.
if quant_dtype is not None:
_quant_config = FusedMoEQuantConfig(
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
)
else:
_quant_config = FusedMoEQuantConfig()
# Warn once when dp_size > 1 and no quantization scheme was determined.
if moe_parallel_config.dp_size > 1:
logger.warning_once("MoE DP setup unable to determine "
"quantization scheme or unsupported "
"quantization type. This model will "
"not run with DP enabled.")
# Presumably pairs with the outer QuantizationConfig check: quant_config
# is None or already a FusedMoEQuantConfig, so use it as is — verify.
else:
_quant_config = quant_config
return FusedMoEConfig(
num_experts=num_experts,
experts_per_token=experts_per_token,
hidden_dim=hidden_dim,
num_local_experts=num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=in_dtype,
quant_config=_quant_config,
max_num_tokens=max_num_tokens,
has_bias=has_bias,
)
......@@ -211,21 +211,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
))
assert quant_config.use_fp8_w8a8
super().__init__(quant_config)
self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1
self.ab_strides2 = ab_strides2
......@@ -247,19 +240,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None
if expert_tokens_meta is not None:
......@@ -273,9 +261,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
in_dtype = hidden_states.dtype
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
self.c_strides2, workspace13, workspace2, expert_num_tokens,
global_num_experts, expert_map, self.w1_scale, self.w2_scale,
a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2,
self.c_strides1, self.c_strides2, workspace13, workspace2,
expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
use_batched_format, topk_weights)
......@@ -286,23 +275,19 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape,
quant_config,
)
@property
......@@ -348,23 +333,19 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
max_experts_per_worker: int,
num_dispatchers: int,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape,
quant_config,
)
assert max_experts_per_worker > 0
self.max_experts_per_worker = max_experts_per_worker
......@@ -414,16 +395,12 @@ def cutlass_moe_fp8(
w2_q: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None,
quant_config: FusedMoEQuantConfig,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
......@@ -475,10 +452,18 @@ def cutlass_moe_fp8(
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
if per_act_token is None:
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
per_out_ch = w1_scale.numel() != w1_q.size(0)
assert quant_config is not None
if quant_config.a1_scale is not None:
assert (quant_config.per_act_token_quant ==
quant_config.a1_scale.numel() != 1)
if quant_config.a2_scale is not None:
assert (quant_config.per_act_token_quant ==
quant_config.a2_scale.numel() != 1)
assert (quant_config.w1_scale is None
or (quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1)
== w1_q.size(1))))
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
0)
......@@ -487,12 +472,11 @@ def cutlass_moe_fp8(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
quant_config=quant_config,
),
)
......@@ -502,14 +486,9 @@ def cutlass_moe_fp8(
w2_q,
topk_weights,
topk_ids,
False,
activation,
num_experts,
expert_map,
w1_scale,
w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
activation=activation,
global_num_experts=num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -542,7 +521,7 @@ def run_cutlass_moe_fp4(
) -> None:
"""
MoE implementation for FP4 Inputs
# Gemm 1
a: Input tensor: [m, k] (half/bfloat16)
a1_gscale: Activation scale per expert: [e] (float32)
......@@ -552,16 +531,16 @@ def run_cutlass_moe_fp4(
full precision)
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
(Block size = 16 for NVFP4)
# Gemm 2
a2_gscale: Activation scale per expert: [e]
w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
topk_weights: [m, topk] dtype: float8
topk_ids: [m, topk] dtype: float8
m, n, k: Unquantized weight shapes, dtype: int
e: number of experts, dtype: int
......@@ -652,42 +631,21 @@ def run_cutlass_moe_fp4(
return
# Split into batched and non-batched
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
max_experts_per_worker: int,
out_dtype: torch.dtype,
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
use_batched_format: bool = False,
):
super().__init__(
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
FusedMoEQuantConfig(
quant_dtype=None, # skip quantization in prepare/finalize
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
))
super().__init__(quant_config)
self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
# TODO(bnell): put this stuff into quant config?
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
@property
def activation_formats(
self
......@@ -746,12 +704,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: torch.Tensor,
a1q_scale: Optional[torch.Tensor], # unused
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -765,11 +718,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
a=hidden_states,
a1_gscale=self.a1_gscale,
w1_fp4=w1,
w1_blockscale=w1_scale,
w1_blockscale=self.w1_scale,
w1_alphas=self.g1_alphas,
a2_gscale=self.a2_gscale,
w2_fp4=w2,
w2_blockscale=w2_scale,
w2_blockscale=self.w2_scale,
w2_alphas=self.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
......@@ -788,14 +741,9 @@ def cutlass_moe_fp4(
a: torch.Tensor,
w1_fp4: torch.Tensor,
w2_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w2_blockscale: torch.Tensor,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
m: int,
n: int,
k: int,
......@@ -805,17 +753,31 @@ def cutlass_moe_fp4(
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")
# TODO(bnell): this feels a bit hacky
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
quant_config = FusedMoEQuantConfig.make(
quant_dtype=None, # skip quantization in prepare/finalize
per_act_token_quant=quant_config.per_act_token_quant,
per_out_ch_quant=quant_config.per_out_ch_quant,
block_shape=quant_config.block_shape,
g1_alphas=quant_config.g1_alphas,
g2_alphas=quant_config.g2_alphas,
a1_gscale=quant_config.a1_gscale,
a2_gscale=quant_config.a2_gscale,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
g1_alphas,
g2_alphas,
a1_gscale,
a2_gscale,
max_experts_per_worker=e,
out_dtype=a.dtype,
per_act_token_quant=False,
per_out_ch_quant=False,
quant_config=quant_config,
use_batched_format=False,
),
)
......@@ -830,10 +792,6 @@ def cutlass_moe_fp4(
activation="silu",
global_num_experts=e,
expert_map=None,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
a1_scale=None,
a2_scale=None,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -891,6 +849,7 @@ def _valid_cutlass_block_scaled_grouped_gemm(
return True
# TODO(bnell): would be nice combine/integrate with regular cutlass_fp8.
def run_cutlass_block_scaled_fused_experts(
a: torch.Tensor,
w1: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import Optional
import torch
......@@ -9,9 +8,11 @@ from tqdm import tqdm
import vllm.envs as env
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce)
compute_aligned_M, deep_gemm_block_shape, deepgemm_moe_permute,
deepgemm_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
......@@ -25,14 +26,6 @@ from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
logger = init_logger(__name__)
@functools.cache
def deep_gemm_block_shape() -> list[int]:
    """Return a two-element block shape from deep_gemm's M alignment.

    Both entries are the value of
    deep_gemm.get_m_alignment_for_contiguous_layout(). The result is cached
    by functools.cache, so repeated calls return the same list object.
    """
    # Lazy import to avoid CUDA initialization problems.
    import deep_gemm as dg

    alignment = dg.get_m_alignment_for_contiguous_layout()
    return [alignment] * 2
def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
    """Check an (M, N, K) problem size against the deep_gemm block alignment.

    The shape is valid when M is at least the alignment and both N and K
    are multiples of it.
    """
    align = deep_gemm_block_shape()[0]
    if M < align:
        return False
    return N % align == 0 and K % align == 0
......@@ -163,13 +156,12 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor,
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=deep_gemm_block_shape(),
))
def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__(quant_config)
assert quant_config.block_shape == deep_gemm_block_shape()
assert quant_config.quant_dtype == torch.float8_e4m3fn
assert not quant_config.per_act_token_quant
assert not quant_config.per_out_ch_quant
@property
def activation_formats(
......@@ -221,21 +213,17 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
assert self.block_shape is not None
assert a1q_scale is not None
assert w1_scale is not None
assert w2_scale is not None
assert self.a2_scale is None
assert self.block_shape is not None
assert self.w1_scale is not None
assert self.w2_scale is not None
a1q = hidden_states
_, N, K = w1.size()
......@@ -270,7 +258,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
aq_out=a1q_perm)
assert a1q.size(0) == M_sum
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, self.w1_scale),
mm1_out, expert_ids)
self.activation(activation, act_out, mm1_out.view(-1, N))
......@@ -281,7 +269,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
column_major_scales=True,
out_q=quant_out)
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, self.w2_scale),
mm2_out, expert_ids)
if apply_router_weight_on_input:
......@@ -348,9 +336,16 @@ def deep_gemm_moe_fp8(
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=deep_gemm_block_shape())
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
DeepGemmExperts(),
DeepGemmExperts(quant_config),
)
return fn(
hidden_states,
......@@ -358,13 +353,9 @@ def deep_gemm_moe_fp8(
w2,
topk_weights,
topk_ids,
inplace,
activation,
global_num_experts,
expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -183,8 +183,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -204,7 +202,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# Quant and Dispatch
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
a1_scale,
quant_config.a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
......@@ -215,7 +213,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else:
a1q = a1
a1q_scale = None
a1_post_scale = a1_scale
a1_post_scale = quant_config.a1_scale
return (lambda *args: None,
self._do_dispatch(tokens=a1q,
......@@ -229,8 +227,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -238,9 +234,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
(_, receiver) = self.prepare_async(a1, a1_scale, a2_scale,
topk_weights, topk_ids, num_experts,
expert_map,
(_, receiver) = self.prepare_async(a1, topk_weights, topk_ids,
num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
return receiver()
......
......@@ -77,15 +77,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def _do_quant(
self,
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
a1_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
block_k = block_shape[1] if block_shape is not None else None
if self.use_fp8_dispatch:
block_k = quant_config.block_shape[
1] if quant_config.block_shape is not None else None
if block_k == DEEPEP_QUANT_BLOCK_SIZE:
# DeepEP kernels did the quantization for us.
x, x_scales = x
......@@ -101,12 +99,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype,
per_act_token_quant,
block_shape)
x, x_scales = moe_kernel_quantize_input(
x, quant_config.a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
x = x.view((num_experts, -1, hidden_dim))
if quant_dtype is not None:
if quant_config.quant_dtype is not None:
assert x_scales is not None
x_scales = normalize_batched_scales_shape(x_scales, num_experts)
......@@ -118,8 +116,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -139,9 +135,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
assert hidden_size % 128 == 0, \
"DeepEP kernels quantize the inputs in blocks of shape 128"
has_per_token_scales = a1_scale.numel(
) != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
has_per_token_scales = quant_config.a1_scale.numel(
) != 1 if quant_config.a1_scale is not None else (
quant_config.a2_scale.numel() != 1
if quant_config.a2_scale is not None else False)
assert not has_per_token_scales, (
"low_latency kernels doesn't support dispatching per-token scales")
......@@ -163,20 +160,21 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return_recv_hook=True)
self.handles[a2a_idx] = handle
return (hook, lambda: self._receiver(expert_x, expert_num_tokens,
a1_scale, a1.dtype, quant_config))
return (
hook,
lambda: self._receiver(expert_x, expert_num_tokens, quant_config.
a1_scale, a1.dtype, quant_config))
def _receiver(
self,
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
expert_num_tokens: torch.Tensor,
a1_scale,
a1_dtype,
a1_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype,
quant_config)
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
......@@ -186,8 +184,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -195,8 +191,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(a1, a1_scale, a2_scale,
topk_weights, topk_ids,
hook, receiver = self.prepare_async(a1, topk_weights, topk_ids,
num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import Optional
import torch
......@@ -44,33 +44,20 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
out_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None],
quant_config: FusedMoEQuantConfig,
ep_rank: int = 0,
ep_size: int = 1,
tp_rank: int = 0,
tp_size: int = 1,
):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=quant_dtype,
per_act_token_quant=False,
block_shape=None,
))
assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
super().__init__(quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
"Only nvfp4,fp8 quantization are currently supported.")
self.ep_rank = ep_rank
self.ep_size = ep_size
self.tp_rank = tp_rank
self.tp_size = tp_size
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
self.out_dtype = out_dtype
@property
......@@ -141,12 +128,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], # Not used
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -162,17 +144,17 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
fc2_expert_weights = w2
else:
# Ensure w1_scale and w2_scale are not None before calling view
assert w1_scale is not None and w2_scale is not None, (
assert self.w1_scale is not None and self.w2_scale is not None, (
"w1_scale and w2_scale must not "
"be None for FlashInferExperts")
# Flashinfer CUTLASS kernel takes scalar global scales,
# min because inv_scale.
quant_scales = [
self.a1_gscale,
w1_scale.view(torch.int32),
self.w1_scale.view(torch.int32),
self.g1_alphas,
self.a2_gscale,
w2_scale.view(torch.int32),
self.w2_scale.view(torch.int32),
self.g2_alphas,
]
# FlashInfer API requires weight to be long for nvfp4
......@@ -202,12 +184,7 @@ def flashinfer_cutlass_moe_fp4(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
quant_config: FusedMoEQuantConfig,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
......@@ -216,15 +193,10 @@ def flashinfer_cutlass_moe_fp4(
) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel(
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False,
a1_gscale=a1_gscale),
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False),
FlashInferExperts(
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
out_dtype=hidden_states.dtype,
quant_dtype="nvfp4",
quant_config=quant_config,
))
return fused_experts(
......@@ -237,7 +209,5 @@ def flashinfer_cutlass_moe_fp4(
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -22,13 +22,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def __init__(
self,
use_dp: bool,
a1_gscale: Optional[torch.Tensor],
num_dispatchers: int = 1,
):
super().__init__()
self.num_dispatchers_ = num_dispatchers
self.use_dp = use_dp
self.a1_gscale = a1_gscale
self.local_tokens = None
@property
......@@ -47,14 +45,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], # Not used
a2_scale: Optional[torch.Tensor], # Not used
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
# TODO(bnell): use quant_config + scales instead of ctor args
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
......@@ -67,7 +62,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
self.a1_gscale,
quant_config.a1_gscale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List # noqa: UP035
from typing import Optional
import torch
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import direct_register_custom_op
def flashinfer_fused_moe_blockscale_fp8(
    routing_logits: torch.Tensor,
    routing_bias: torch.Tensor,
    x: torch.Tensor,
    w13_weight: torch.Tensor,
    w13_weight_scale_inv: torch.Tensor,
    w2_weight: torch.Tensor,
    w2_weight_scale_inv: torch.Tensor,
    global_num_experts: int,
    top_k: int,
    num_expert_group: int,
    topk_group: int,
    intermediate_size: int,
    expert_offset: int,
    local_num_experts: int,
    block_shape: List[int],  #noqa: UP006
    routed_scaling: float = 1.0,
) -> torch.Tensor:
    """Quantize ``x`` group-wise to FP8 and run FlashInfer's block-scale MoE.

    The activations are quantized with ``per_token_group_quant_fp8`` using
    ``block_shape[1]`` as the group size, the resulting scale tensor is
    transposed and made contiguous, and everything is forwarded to
    ``flashinfer_trtllm_fp8_block_scale_moe``.

    Args:
        routing_logits: Forwarded unchanged as ``routing_logits``.
        routing_bias: Forwarded unchanged as ``routing_bias``.
        x: Unquantized activations; ``x.shape[0]`` is also used to pick the
            tile size.
        w13_weight / w13_weight_scale_inv: Forwarded as the first GEMM's
            weights and weight scales.
        w2_weight / w2_weight_scale_inv: Forwarded as the second GEMM's
            weights and weight scales.
        global_num_experts: Forwarded as ``num_experts``.
        top_k: Forwarded as ``top_k``.
        num_expert_group: Forwarded as ``n_group``.
        topk_group: Forwarded as ``topk_group``.
        intermediate_size: Forwarded as ``intermediate_size``.
        expert_offset: Forwarded as ``local_expert_offset``.
        local_num_experts: Forwarded as ``local_num_experts``.
        block_shape: Must equal ``[128, 128]``.
        routed_scaling: Forwarded as ``routed_scaling_factor``.

    Returns:
        Whatever ``flashinfer_trtllm_fp8_block_scale_moe`` returns.

    Raises:
        AssertionError: If any of the argument constraints below fails.
    """
    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

    # Argument constraints, checked before any quantization work is done.
    assert top_k <= global_num_experts
    assert top_k <= 8
    assert topk_group <= 4
    assert global_num_experts > num_expert_group
    assert global_num_experts % num_expert_group == 0
    assert global_num_experts % 4 == 0
    assert top_k < (topk_group * global_num_experts / num_expert_group)
    assert block_shape == [128, 128]

    quant_group_size = block_shape[1]
    x_fp8, x_group_scales = per_token_group_quant_fp8(x, quant_group_size)

    # NOTE: scales of hidden states have to be transposed!
    x_group_scales_t = x_group_scales.t().contiguous()

    tile_tokens_dim = calculate_tile_tokens_dim(x.shape[0], top_k,
                                                global_num_experts)

    return flashinfer_trtllm_fp8_block_scale_moe(
        # Routing inputs.
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        # Quantized activations and their (transposed) scales.
        hidden_states=x_fp8,
        hidden_states_scale=x_group_scales_t,
        # Expert weights and weight scales.
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale_inv,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale_inv,
        # Expert/topology configuration.
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling,
        tile_tokens_dim=tile_tokens_dim,
        routing_method_type=2,  # DeepSeek-styled routing method
        use_shuffled_weight=False,
    )
def flashinfer_fused_moe_blockscale_fp8_fake(
    routing_logits: torch.Tensor,
    routing_bias: torch.Tensor,
    x: torch.Tensor,
    w13_weight: torch.Tensor,
    w13_weight_scale_inv: torch.Tensor,
    w2_weight: torch.Tensor,
    w2_weight_scale_inv: torch.Tensor,
    global_num_experts: int,
    top_k: int,
    num_expert_group: int,
    topk_group: int,
    intermediate_size: int,
    expert_offset: int,
    local_num_experts: int,
    block_shape: list[int],
    routed_scaling: float = 1.0,
) -> torch.Tensor:
    """Stand-in for ``flashinfer_fused_moe_blockscale_fp8`` that runs no kernel.

    Mirrors the real function's parameter list but ignores every argument
    except ``x``.

    Returns:
        A new, uninitialized tensor created with ``torch.empty_like(x)``.
    """
    output_placeholder = torch.empty_like(x)
    return output_placeholder
# TODO(bnell): Does this really need to be a torch.op?
# Register the block-scale FP8 MoE wrapper under the op name
# "flashinfer_fused_moe_blockscale_fp8", pairing the real function with its
# fake counterpart.
# NOTE(review): presumably the fake impl is what tracing/compilation uses in
# place of the real kernel -- confirm against direct_register_custom_op.
direct_register_custom_op(
    op_name="flashinfer_fused_moe_blockscale_fp8",
    op_func=flashinfer_fused_moe_blockscale_fp8,
    # No argument is declared as mutated in place.
    mutates_args=[],
    fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
def flashinfer_fused_moe_per_tensor_scale_fp8(
    routing_logits: torch.Tensor,
    routing_bias: Optional[torch.Tensor],
    hidden_states: torch.Tensor,
    input_scale: torch.Tensor,
    gemm1_weights: torch.Tensor,
    gemm2_weights: torch.Tensor,
    output1_scales_scalar: torch.Tensor,
    output1_scales_gate_scalar: torch.Tensor,
    output2_scales_scalar: torch.Tensor,
    num_experts: int,
    top_k: int,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    use_routing_scales_on_input: bool,
    routing_method_type: int,
    routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
    """Quantize activations to FP8 and run FlashInfer's per-tensor-scale MoE.

    ``hidden_states`` is quantized to ``torch.float8_e4m3fn`` with
    ``moe_kernel_quantize_input`` (``per_act_token_quant=False``) using
    ``input_scale``; the scale returned by that call is discarded. The
    quantized activations and all remaining arguments are then forwarded to
    ``flashinfer_trtllm_fp8_per_tensor_scale_moe``.

    Args:
        num_expert_group: Forwarded as ``n_group``; ``None`` is replaced by 0.
        topk_group: Forwarded as ``topk_group``; ``None`` is replaced by 0.
        All other arguments are passed through under the same (or the
        obviously corresponding) keyword name.

    Returns:
        Whatever ``flashinfer_trtllm_fp8_per_tensor_scale_moe`` returns.
    """
    # The kernel is given 0 rather than None when grouping is not configured.
    if num_expert_group is None:
        num_expert_group = 0
    if topk_group is None:
        topk_group = 0

    fp8_hidden_states, _unused_scale = moe_kernel_quantize_input(
        hidden_states,
        input_scale,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
    )

    from vllm.utils.flashinfer import (
        flashinfer_trtllm_fp8_per_tensor_scale_moe)

    # Tile size is derived from the token count of the unquantized input.
    tile_tokens_dim = calculate_tile_tokens_dim(hidden_states.shape[0],
                                                top_k, num_experts)

    return flashinfer_trtllm_fp8_per_tensor_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=fp8_hidden_states,
        gemm1_weights=gemm1_weights,
        output1_scales_scalar=output1_scales_scalar,
        output1_scales_gate_scalar=output1_scales_gate_scalar,
        gemm2_weights=gemm2_weights,
        output2_scales_scalar=output2_scales_scalar,
        num_experts=num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=local_expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling_factor,
        use_routing_scales_on_input=use_routing_scales_on_input,
        tile_tokens_dim=tile_tokens_dim,
        routing_method_type=routing_method_type,
    )
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
    routing_logits: torch.Tensor,
    routing_bias: Optional[torch.Tensor],
    hidden_states: torch.Tensor,
    input_scale: torch.Tensor,
    gemm1_weights: torch.Tensor,
    gemm2_weights: torch.Tensor,
    output1_scales_scalar: torch.Tensor,
    output1_scales_gate_scalar: torch.Tensor,
    output2_scales_scalar: torch.Tensor,
    num_experts: int,
    top_k: int,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    use_routing_scales_on_input: bool,
    routing_method_type: int,
    routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
    """Stand-in for ``flashinfer_fused_moe_per_tensor_scale_fp8``; runs no kernel.

    Mirrors the real function's parameter list but ignores every argument
    except ``hidden_states``.

    Returns:
        A new, uninitialized tensor created with
        ``torch.empty_like(hidden_states)``.
    """
    output_placeholder = torch.empty_like(hidden_states)
    return output_placeholder
# TODO(bnell): Does this really need to be a torch.op?
# Register the per-tensor-scale FP8 MoE wrapper under the op name
# "flashinfer_fused_moe_per_tensor_scale_fp8", pairing the real function with
# its fake counterpart.
# NOTE(review): presumably the fake impl is what tracing/compilation uses in
# place of the real kernel -- confirm against direct_register_custom_op.
direct_register_custom_op(
    op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
    op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
    # NOTE(review): hidden_states is declared as mutated, but the wrapper
    # only passes it to moe_kernel_quantize_input and uses the returned
    # tensor -- confirm whether an in-place write actually happens.
    mutates_args=["hidden_states"],
    fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
......@@ -8,7 +8,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config)
try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched)
from vllm.model_executor.layers.fused_moe.utils import (
......@@ -498,8 +498,6 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -545,14 +543,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dtype=torch.float32,
device=a1.device)
else:
assert a1_scale is None
assert quant_config.a1_scale is None
b_a1_scale = None
first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts
a1_scale = normalize_scales_shape(a1_scale)
a2_scale = normalize_scales_shape(a2_scale)
a1_scale = normalize_scales_shape(quant_config.a1_scale)
for expert_id in range(first_expert, last_expert):
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
......@@ -623,28 +620,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
self,
max_num_tokens: int,
num_dispatchers: int,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
assert not use_mxfp4_w4a4, "NYI"
super().__init__(quant_config)
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert not self.quant_config.use_mxfp4_w4a4, "NYI"
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
......@@ -705,12 +687,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -740,10 +717,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
tmp = _resize_cache(workspace2, (num, N))
if self.quant_config.is_quantized:
assert a1q_scale is not None and w1_scale is not None
assert a1q_scale is not None and self.w1_scale is not None
input = self.dequant(hidden_states[expert, :, :],
a1q_scale[expert])
w1_dq = self.dequant(w1[expert], w1_scale[expert])
w1_dq = self.dequant(w1[expert], self.w1_scale[expert])
input = input[:num] @ w1_dq.transpose(0, 1)
else:
input = hidden_states[expert, :num, :] @ w1[expert].transpose(
......@@ -752,8 +729,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.activation(activation, tmp, input.to(tmp.dtype))
if self.quant_config.is_quantized:
assert w2_scale is not None
w2_dq = self.dequant(w2[expert], w2_scale[expert])
assert self.w2_scale is not None
w2_dq = self.dequant(w2[expert], self.w2_scale[expert])
else:
w2_dq = w2[expert]
......@@ -840,35 +817,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self,
max_num_tokens: int,
num_dispatchers: int,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
assert not use_mxfp4_w4a4, "NYI"
super().__init__(quant_config)
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert not self.quant_config.use_mxfp4_w4a4, "NYI"
assert max_num_tokens > 0
assert num_dispatchers > 0
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int8_w8a8 = use_int8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a16 = use_int8_w8a16
self.use_mxfp4_w4a4 = use_mxfp4_w4a4
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
......@@ -921,19 +878,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
# Check constraints.
if self.use_int4_w4a16:
if self.quant_config.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
"Hidden size mismatch")
else:
......@@ -958,11 +910,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert w1.size(0) == E
assert w2.size(0) == E
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
dtype=hidden_states.dtype)
config_dtype = self.quant_config.config_name(hidden_states.dtype)
config = try_get_optimal_moe_config(
w1.size(),
......@@ -992,7 +940,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache2 = _resize_cache(workspace2,
(E, max_num_tokens, N // 2))
if self.use_fp8_w8a8:
# TODO(bnell): should this be done for any quantized type?
if self.quant_config.use_fp8_w8a8:
intermediate_cache1.fill_(0)
a1q_scale = normalize_batched_scales_shape(a1q_scale, E)
......@@ -1005,11 +954,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens=expert_num_tokens,
compute_type=compute_type,
A_scale=a1q_scale,
B_scale=w1_scale,
B_zp=w1_zp,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
B_scale=self.w1_scale,
B_zp=self.w1_zp,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
config=config,
per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape)
......@@ -1021,7 +970,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache1.view(-1, N))
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
intermediate_cache2, a2_scale, max_num_tokens, E, N,
intermediate_cache2, self.a2_scale, max_num_tokens, E, N,
expert_num_tokens, self.quant_dtype, self.per_act_token_quant,
self.block_shape)
......@@ -1032,11 +981,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens=expert_num_tokens,
compute_type=compute_type,
A_scale=a2q_scale,
B_scale=w2_scale,
B_zp=w2_zp,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
B_scale=self.w2_scale,
B_zp=self.w2_zp,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
config=config,
per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE kernel."""
"""Fused MoE Triton kernels."""
import functools
import json
import os
# torch.compile needs typing.List. It will fail torch.library.infer_schema
# otherwise
from typing import List # noqa: UP035
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import torch
import torch.nn.functional as F
......@@ -18,7 +18,7 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, get_config_quant_dtype)
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, _get_config_dtype_str)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
_valid_cutlass_block_scaled_grouped_gemm,
run_cutlass_block_scaled_fused_experts)
......@@ -32,11 +32,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
_resize_cache, activation_without_mul, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
dequant_mxfp4)
from vllm.platforms import current_platform
......@@ -1049,87 +1045,66 @@ def fused_grouped_topk(
return topk_values.to(torch.float32), topk_indices.to(torch.int32)
def get_config_dtype_str(
dtype: torch.dtype,
use_int4_w4a16: Optional[bool] = False,
use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False,
use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]:
if use_fp8_w8a8:
return "fp8_w8a8"
elif use_int8_w8a16:
return "int8_w8a16"
elif use_int4_w4a16:
return "int4_w4a16"
elif use_mxfp4_w4a4:
return "mxfp4_w4a4"
elif dtype == torch.float:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
return "float32"
return None
def inplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, #noqa: UP006
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> None:
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, #noqa: UP006
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, is_act_and_mul,
apply_router_weight_on_input, use_fp8_w8a8,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
use_mxfp4_w4a4, per_channel_quant, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape, w1_bias, w2_bias)
def inplace_fused_experts_fake(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> None:
def inplace_fused_experts_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, #noqa: UP006
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> None:
pass
......@@ -1143,175 +1118,6 @@ direct_register_custom_op(
)
def flashinfer_fused_moe_blockscale_fp8(
    routing_logits: torch.Tensor,
    routing_bias: torch.Tensor,
    x: torch.Tensor,
    w13_weight: torch.Tensor,
    w13_weight_scale_inv: torch.Tensor,
    w2_weight: torch.Tensor,
    w2_weight_scale_inv: torch.Tensor,
    global_num_experts: int,
    top_k: int,
    num_expert_group: int,
    topk_group: int,
    intermediate_size: int,
    expert_offset: int,
    local_num_experts: int,
    block_shape: List[int],  #noqa: UP006
    routed_scaling: float = 1.0,
) -> torch.Tensor:
    """Quantize ``x`` group-wise to FP8 and run FlashInfer's block-scale MoE.

    The activations are quantized with ``per_token_group_quant_fp8`` using
    ``block_shape[1]`` as the group size, the resulting scale tensor is
    transposed and made contiguous, and everything is forwarded to
    ``flashinfer_trtllm_fp8_block_scale_moe``.

    Args:
        x: Unquantized activations; ``x.shape[0]`` is also used to pick the
            tile size.
        global_num_experts: Forwarded as ``num_experts``.
        num_expert_group: Forwarded as ``n_group``.
        expert_offset: Forwarded as ``local_expert_offset``.
        block_shape: Must equal ``[128, 128]``.
        routed_scaling: Forwarded as ``routed_scaling_factor``.
        The remaining arguments are passed through under the corresponding
        keyword names of the FlashInfer call.

    Returns:
        Whatever ``flashinfer_trtllm_fp8_block_scale_moe`` returns.

    Raises:
        AssertionError: If any of the argument constraints below fails.
    """
    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

    # Argument constraints, checked before any quantization work is done.
    assert top_k <= global_num_experts
    assert top_k <= 8
    assert topk_group <= 4
    assert global_num_experts > num_expert_group
    assert global_num_experts % num_expert_group == 0
    assert global_num_experts % 4 == 0
    assert top_k < (topk_group * global_num_experts / num_expert_group)
    assert block_shape == [128, 128]

    quant_group_size = block_shape[1]
    x_fp8, x_group_scales = per_token_group_quant_fp8(x, quant_group_size)

    # NOTE: scales of hidden states have to be transposed!
    x_group_scales_t = x_group_scales.t().contiguous()

    tile_tokens_dim = calculate_tile_tokens_dim(x.shape[0], top_k,
                                                global_num_experts)

    return flashinfer_trtllm_fp8_block_scale_moe(
        # Routing inputs.
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        # Quantized activations and their (transposed) scales.
        hidden_states=x_fp8,
        hidden_states_scale=x_group_scales_t,
        # Expert weights and weight scales.
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale_inv,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale_inv,
        # Expert/topology configuration.
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling,
        tile_tokens_dim=tile_tokens_dim,
        routing_method_type=2,  # DeepSeek-styled routing method
        use_shuffled_weight=False,
    )
def flashinfer_fused_moe_blockscale_fp8_fake(
    routing_logits: torch.Tensor,
    routing_bias: torch.Tensor,
    x: torch.Tensor,
    w13_weight: torch.Tensor,
    w13_weight_scale_inv: torch.Tensor,
    w2_weight: torch.Tensor,
    w2_weight_scale_inv: torch.Tensor,
    global_num_experts: int,
    top_k: int,
    num_expert_group: int,
    topk_group: int,
    intermediate_size: int,
    expert_offset: int,
    local_num_experts: int,
    block_shape: list[int],
    routed_scaling: float = 1.0,
) -> torch.Tensor:
    """Stand-in for ``flashinfer_fused_moe_blockscale_fp8`` that runs no kernel.

    Mirrors the real function's parameter list but ignores every argument
    except ``x``.

    Returns:
        A new, uninitialized tensor created with ``torch.empty_like(x)``.
    """
    output_placeholder = torch.empty_like(x)
    return output_placeholder
# Register the block-scale FP8 MoE wrapper under the op name
# "flashinfer_fused_moe_blockscale_fp8", pairing the real function with its
# fake counterpart.
# NOTE(review): presumably the fake impl is what tracing/compilation uses in
# place of the real kernel -- confirm against direct_register_custom_op.
direct_register_custom_op(
    op_name="flashinfer_fused_moe_blockscale_fp8",
    op_func=flashinfer_fused_moe_blockscale_fp8,
    # No argument is declared as mutated in place.
    mutates_args=[],
    fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
def flashinfer_fused_moe_per_tensor_scale_fp8(
    routing_logits: torch.Tensor,
    routing_bias: Optional[torch.Tensor],
    hidden_states: torch.Tensor,
    input_scale: torch.Tensor,
    gemm1_weights: torch.Tensor,
    gemm2_weights: torch.Tensor,
    output1_scales_scalar: torch.Tensor,
    output1_scales_gate_scalar: torch.Tensor,
    output2_scales_scalar: torch.Tensor,
    num_experts: int,
    top_k: int,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    use_routing_scales_on_input: bool,
    routing_method_type: int,
    routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
    """Quantize activations to FP8 and run FlashInfer's per-tensor-scale MoE.

    ``hidden_states`` is quantized to ``torch.float8_e4m3fn`` with
    ``moe_kernel_quantize_input`` (``per_act_token_quant=False``) using
    ``input_scale``; the scale returned by that call is discarded. The
    quantized activations and all remaining arguments are then forwarded to
    ``flashinfer_trtllm_fp8_per_tensor_scale_moe``.

    Args:
        num_expert_group: Forwarded as ``n_group``; ``None`` is replaced by 0.
        topk_group: Forwarded as ``topk_group``; ``None`` is replaced by 0.
        All other arguments are passed through under the same (or the
        obviously corresponding) keyword name.

    Returns:
        Whatever ``flashinfer_trtllm_fp8_per_tensor_scale_moe`` returns.
    """
    # The kernel is given 0 rather than None when grouping is not configured.
    if num_expert_group is None:
        num_expert_group = 0
    if topk_group is None:
        topk_group = 0

    fp8_hidden_states, _unused_scale = moe_kernel_quantize_input(
        hidden_states,
        input_scale,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
    )

    from vllm.utils.flashinfer import (
        flashinfer_trtllm_fp8_per_tensor_scale_moe)

    # Tile size is derived from the token count of the unquantized input.
    tile_tokens_dim = calculate_tile_tokens_dim(hidden_states.shape[0],
                                                top_k, num_experts)

    return flashinfer_trtllm_fp8_per_tensor_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=fp8_hidden_states,
        gemm1_weights=gemm1_weights,
        output1_scales_scalar=output1_scales_scalar,
        output1_scales_gate_scalar=output1_scales_gate_scalar,
        gemm2_weights=gemm2_weights,
        output2_scales_scalar=output2_scales_scalar,
        num_experts=num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=local_expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling_factor,
        use_routing_scales_on_input=use_routing_scales_on_input,
        tile_tokens_dim=tile_tokens_dim,
        routing_method_type=routing_method_type,
    )
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
        routing_logits: torch.Tensor,
        routing_bias: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        input_scale: torch.Tensor,
        gemm1_weights: torch.Tensor,
        gemm2_weights: torch.Tensor,
        output1_scales_scalar: torch.Tensor,
        output1_scales_gate_scalar: torch.Tensor,
        output2_scales_scalar: torch.Tensor,
        num_experts: int,
        top_k: int,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
        intermediate_size: int,
        local_expert_offset: int,
        local_num_experts: int,
        use_routing_scales_on_input: bool,
        routing_method_type: int,
        routed_scaling_factor: float = 1.0) -> torch.Tensor:
    """Stand-in for ``flashinfer_fused_moe_per_tensor_scale_fp8``; runs no kernel.

    Mirrors the real function's parameter list but ignores every argument
    except ``hidden_states``.

    Returns:
        A new, uninitialized tensor created with
        ``torch.empty_like(hidden_states)``.
    """
    # The previous body was `pass`, which returned None despite the declared
    # `-> torch.Tensor` return type. Return a tensor placeholder instead, the
    # same way the block-scale fake implementation does.
    return torch.empty_like(hidden_states)
# Register the per-tensor-scale FP8 MoE wrapper under the op name
# "flashinfer_fused_moe_per_tensor_scale_fp8", pairing the real function with
# its fake counterpart.
# NOTE(review): presumably the fake impl is what tracing/compilation uses in
# place of the real kernel -- confirm against direct_register_custom_op.
direct_register_custom_op(
    op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
    op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
    # NOTE(review): hidden_states is declared as mutated, but the wrapper
    # only passes it to moe_kernel_quantize_input and uses the returned
    # tensor -- confirm whether an in-place write actually happens.
    mutates_args=["hidden_states"],
    fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
def outplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
......@@ -1319,7 +1125,6 @@ def outplace_fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
......@@ -1341,37 +1146,37 @@ def outplace_fused_experts(
) -> torch.Tensor:
return fused_experts_impl(
hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias)
apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8,
use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant,
global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, w1_bias, w2_bias)
def outplace_fused_experts_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
is_act_and_mul: bool = True,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -1403,45 +1208,36 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# torch ops.
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
quant_config: Optional[FusedMoEQuantConfig] = None,
allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
use_fp8_w8a8 = quant_config.use_fp8_w8a8
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
# However, on B200, we use DeepGemm for all cases because they only support
# E8M0 scale, which means we requantize the weight and input to the specific
# scale. Falling back to cutlass or triton for some cases would cause
# accuracy issues.
if (allow_deep_gemm and use_fp8_w8a8 and
if (allow_deep_gemm and quant_config.use_fp8_w8a8 and
(is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))):
assert quant_config is not None
assert apply_router_weight_on_input is False
assert is_act_and_mul, (
"DeepGemm only supports is_act_and_mul=True for now.")
return deep_gemm_moe_fp8(
hidden_states=hidden_states,
w1=w1,
......@@ -1452,22 +1248,23 @@ def fused_experts(hidden_states: torch.Tensor,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
a1_scale=quant_config.a1_scale,
a2_scale=quant_config.a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
and _valid_cutlass_block_scaled_grouped_gemm(
w1, w2, inplace, activation, apply_router_weight_on_input,
expert_map)):
assert quant_config is not None
return run_cutlass_block_scaled_fused_experts(
a=hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids)
else:
......@@ -1478,26 +1275,49 @@ def fused_experts(hidden_states: torch.Tensor,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
is_act_and_mul=is_act_and_mul,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_channel_quant=per_channel_quant,
use_fp8_w8a8=quant_config.use_fp8_w8a8,
use_int8_w8a8=quant_config.use_int8_w8a8,
use_int8_w8a16=quant_config.use_int8_w8a16,
use_int4_w4a16=quant_config.use_int4_w4a16,
use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4,
per_channel_quant=quant_config.per_act_token_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
w1_bias=w1_bias,
w2_bias=w2_bias,
)
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
w1_zp=quant_config.w1_zp,
w2_zp=quant_config.w2_zp,
a1_scale=quant_config.a1_scale,
a2_scale=quant_config.a2_scale,
block_shape=quant_config.block_shape,
w1_bias=quant_config.w1_bias,
w2_bias=quant_config.w2_bias)
# Activation-name strings for the ungated ("without multiplication") variants
# of silu and gelu. fused_experts_impl compares its `activation` argument
# against these to select plain F.silu / F.gelu instead of the *_and_mul ops.
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
def _get_config_quant_dtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_mxfp4_w4a4: bool,
) -> Union[None, torch.dtype, str]:
"""
Get the quantization type based on the quantization strategy flags.
We don't have a quant_config at this point so we need to work backwards.
A return type of None means no quantization is required because the
input is unquantized or has been quantized prior to calling
fused_experts_impl.
"""
if use_fp8_w8a8:
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
elif use_mxfp4_w4a4:
return "mxfp4"
return None
def fused_experts_impl(
......@@ -1508,7 +1328,6 @@ def fused_experts_impl(
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
is_act_and_mul: bool = True,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
......@@ -1557,17 +1376,18 @@ def fused_experts_impl(
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
dtype=hidden_states.dtype)
qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4)
config_dtype = _get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
dtype=hidden_states.dtype)
# Note: for use_int8_w8a16 or use_int4_w4a16, the activations are
# quantized prior to calling fused_experts.
quant_dtype = _get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_mxfp4_w4a4=use_mxfp4_w4a4)
get_config_func = functools.partial(
try_get_optimal_moe_config,
......@@ -1640,7 +1460,7 @@ def fused_experts_impl(
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
A=curr_hidden_states,
A_scale=a1_scale,
quant_dtype=qtype,
quant_dtype=quant_dtype,
per_act_token_quant=per_channel_quant,
block_shape=block_shape)
......@@ -1671,30 +1491,29 @@ def fused_experts_impl(
B_bias=w1_bias)
# Activation function with multiplication
if activation == "silu" and is_act_and_mul:
if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
elif activation == "gelu" and is_act_and_mul:
elif activation == "gelu":
torch.ops._C.gelu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
elif activation == "swigluoai" and is_act_and_mul:
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
# Activation function without multiplication
elif activation == "silu":
elif activation == SILU_NO_MUL:
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
elif activation == "gelu":
elif activation == GELU_NO_MUL:
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
f"with is_act_and_mul={is_act_and_mul}.")
raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2,
A_scale=a2_scale,
quant_dtype=qtype,
quant_dtype=quant_dtype,
per_act_token_quant=per_channel_quant,
block_shape=block_shape)
......@@ -1726,164 +1545,13 @@ def fused_experts_impl(
return out_hidden_states
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    activation: str = "silu",
    is_act_and_mul: bool = True,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    use_mxfp4_w4a4: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Run a Mixture of Experts (MoE) layer: pick the top-k experts per token
    from the gating output, then apply the expert weights w1 and w2 via
    fused_experts.

    Routing is chosen as follows: grouped_topk when use_grouped_topk is set
    (both num_expert_group and topk_group must then be provided), otherwise
    custom_routing_function when one is given, otherwise fused_topk.

    Parameters:
    - hidden_states (torch.Tensor): Input tensor of the MoE layer.
    - w1 (torch.Tensor): First set of expert weights.
    - w2 (torch.Tensor): Second set of expert weights.
    - gating_output (torch.Tensor): Output of the gating operation
        (before softmax).
    - topk (int): Number of experts to select per token.
    - renormalize (bool): If True, renormalize the top-k weights to sum
        to 1.
    - inplace (bool): If True, perform the operation in-place. Must be
        False when is_act_and_mul is False. Defaults to False.
    - activation (str): Activation applied after the first MoE layer.
    - is_act_and_mul (bool): If True, use the self-gated
        activation-and-mul form; otherwise use the ungated activation.
    - use_grouped_topk (bool): If True, route with grouped_topk instead of
        fused_topk (used by the Deepseekv2 model).
    - num_expert_group (Optional[int]): Extra argument for grouped_topk.
    - topk_group (Optional[int]): Extra argument for grouped_topk.
    - custom_routing_function (Optional[Callable]): Called as
        f(hidden_states, gating_output, topk, renormalize) and expected to
        return (topk_weights, topk_ids). Ignored when use_grouped_topk is
        set.
    - use_fp8_w8a8 (bool): Use fp8 arithmetic for the w1 and w2 inner
        products. Defaults to False.
    - use_int8_w8a8 (bool): Use int8 arithmetic for the w1 and w2 inner
        products. Defaults to False.
    - use_int8_w8a16 (bool): Use int8 weights with bf16/fp16 activations
        for the w1 and w2 inner products. Defaults to False.
    - use_int4_w4a16 (bool): Use int4 weights with bf16/fp16 activations
        for the w1 and w2 inner products. Defaults to False.
    - use_mxfp4_w4a4 (bool): Use OCP MXFP4 weights and OCP MXFP4
        activations for the w1 and w2 inner products. Defaults to False.
    - per_channel_quant (bool): Forwarded unchanged to fused_experts.
    - global_num_experts (int): Total number of experts in the global
        expert space.
    - expert_map (Optional[torch.Tensor]): Maps expert indices from the
        global expert space to the local expert space of the expert
        parallel shard.
    - w1_scale / w2_scale (Optional[torch.Tensor]): Optional scales for
        w1 and w2.
    - w1_zp / w2_zp (Optional[torch.Tensor]): Forwarded unchanged to
        fused_experts.
    - a1_scale / a2_scale (Optional[torch.Tensor]): Optional scales for
        a1 and a2.
    - block_shape (Optional[list[int]]): Optional block size for
        block-wise quantization.
    - w1_bias / w2_bias (Optional[torch.Tensor]): Forwarded unchanged to
        fused_experts.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # The ungated activation form is only allowed on the out-of-place path.
    assert is_act_and_mul or inplace is False, (
        "is_act_and_mul=False is not supported with inplace=True")

    routing_args = (hidden_states, gating_output, topk, renormalize)
    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        topk_weights, topk_ids = grouped_topk(*routing_args,
                                              num_expert_group, topk_group)
    elif custom_routing_function is not None:
        topk_weights, topk_ids = custom_routing_function(*routing_args)
    else:
        # fused_topk also returns the token/expert indices; they are not
        # needed here.
        topk_weights, topk_ids, _ = fused_topk(*routing_args)

    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=inplace,
        activation=activation,
        is_act_and_mul=is_act_and_mul,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        use_mxfp4_w4a4=use_mxfp4_w4a4,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zp,
        w2_zp=w2_zp,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a8 = use_int8_w8a8
self.use_int8_w8a16 = use_int8_w8a16
self.use_mxfp4_w4a4 = use_mxfp4_w4a4
super().__init__(quant_config)
@property
def activation_formats(
......@@ -1929,19 +1597,14 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
# Check constraints.
if self.use_int4_w4a16:
if self.quant_config.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
"Hidden size mismatch")
else:
......@@ -1964,17 +1627,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
if global_num_experts == -1:
global_num_experts = E
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
dtype=hidden_states.dtype)
config = try_get_optimal_moe_config(
w1.size(),
w2.size(),
top_k_num,
config_dtype,
self.quant_config.config_name(hidden_states.dtype),
num_tokens,
block_shape=self.block_shape,
)
......@@ -2008,8 +1665,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
self.w1_scale,
self.w1_zp,
None, # topk_weights
sorted_token_ids,
expert_ids,
......@@ -2018,13 +1675,13 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a8=self.quant_config.use_int8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=None # TODO support B_bias
B_bias=self.w1_bias,
)
self.activation(activation, intermediate_cache2,
......@@ -2033,7 +1690,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2q_scale: Optional[torch.Tensor] = None
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
intermediate_cache2, a2_scale, self.quant_dtype,
intermediate_cache2, self.a2_scale, self.quant_dtype,
self.per_act_token_quant, self.block_shape)
invoke_fused_moe_kernel(
......@@ -2041,8 +1698,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
self.w2_scale,
self.w2_zp,
topk_weights,
sorted_token_ids,
expert_ids,
......@@ -2051,36 +1708,21 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
1,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a8=self.quant_config.use_int8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=None # TODO support B_bias
B_bias=self.w2_bias,
)
ops.moe_sum(intermediate_cache3, output)
def modular_triton_fused_moe(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
use_mxfp4_w4a4: bool,
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> mk.FusedMoEModularKernel:
quant_config: FusedMoEQuantConfig) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
),
TritonExperts(quant_config),
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.utils import has_triton_kernels
......@@ -23,9 +25,6 @@ if has_triton_kernels():
"Failed to import Triton kernels. Please make sure your triton "
"version is compatible.")
if TYPE_CHECKING:
from triton_kernels.matmul_ogs import PrecisionConfig
def triton_kernel_moe_forward(
hidden_states: torch.Tensor,
......@@ -35,20 +34,10 @@ def triton_kernel_moe_forward(
topk: int,
renormalize: bool,
activation: str = "silu",
quant_config: Optional[FusedMoEQuantConfig] = None,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_precision: Optional["PrecisionConfig"] = None,
w2_precision: Optional["PrecisionConfig"] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
routing_data, gather_idx, scatter_idx = routing(gating_output,
......@@ -64,20 +53,10 @@ def triton_kernel_moe_forward(
gather_idx,
scatter_idx,
activation=activation,
quant_config=quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_precision=w1_precision,
w2_precision=w2_precision,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
expert_map=expert_map)
# This is a triton implementation of the fused_experts function
......@@ -90,28 +69,23 @@ def triton_kernel_fused_experts(
gather_indx, # GatherIndx
scatter_indx, # ScatterIndx
activation: str = "silu",
quant_config: Optional[FusedMoEQuantConfig] = None,
swiglu_alpha: float = 1.702,
swiglu_limit: float = 7.0,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_precision: Optional["PrecisionConfig"] = None,
w2_precision: Optional["PrecisionConfig"] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
a1q_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
# type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16
assert w1_bias is None or w1_bias.dtype == torch.float32
assert w2_bias is None or w2_bias.dtype == torch.float32
assert (quant_config.w1_bias is None
or quant_config.w1_bias.dtype == torch.float32)
assert (quant_config.w2_bias is None
or quant_config.w2_bias.dtype == torch.float32)
# Shape check, only check non-mxfp4
assert hidden_states.shape[-1] == w1.shape[-2]
......@@ -130,20 +104,20 @@ def triton_kernel_fused_experts(
intermediate_cache1 = matmul_ogs(
hidden_states,
w1,
w1_bias,
quant_config.w1_bias,
routing_data,
gather_indx=gather_indx,
precision_config=w1_precision,
precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act)
intermediate_cache3 = matmul_ogs(
intermediate_cache1,
w2,
w2_bias,
quant_config.w2_bias,
routing_data,
scatter_indx=scatter_indx,
precision_config=w2_precision,
precision_config=quant_config.w2_precision,
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)
......@@ -154,21 +128,13 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
quant_config,
max_num_tokens: int,
num_dispatchers: int,
w1_precision: "PrecisionConfig",
w2_precision: "PrecisionConfig",
w1_bias: Optional[torch.Tensor],
w2_bias: Optional[torch.Tensor],
quant_config: FusedMoEQuantConfig,
):
super().__init__(quant_config)
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
self.w1_precision = w1_precision
self.w2_precision = w2_precision
self.w1_bias = w1_bias
self.w2_bias = w2_bias
@property
def activation_formats(
......@@ -212,12 +178,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -228,20 +189,12 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states,
w1,
w2,
None,
None,
None,
routing_data=None,
gather_indx=None,
scatter_indx=None,
activation=activation,
quant_config=self.quant_config,
apply_router_weight_on_input=False,
use_fp8_w8a8=False,
per_channel_quant=False,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=self.w1_bias,
w2_bias=self.w2_bias,
w1_precision=self.w1_precision,
w2_precision=self.w2_precision,
a1_scale=a1q_scale,
a2_scale=a2_scale)
a1q_scale=a1q_scale)
......@@ -22,7 +22,8 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig)
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, biased_moe_quant_config)
# yapf: enable
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel,
......@@ -78,11 +79,11 @@ class FusedMoeWeightScaleSupported(Enum):
class FusedMoEMethodBase(QuantizeMethodBase):
# TODO(bnell): also pass quant_config?
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.moe = moe
self.fused_experts: Optional[Callable] = None
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.fused_experts: Optional[FusedMoEModularKernel] = None
self.topk_indices_dtype = None
@abstractmethod
......@@ -103,23 +104,28 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@staticmethod
def _maybe_make_prepare_finalize(
moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]:
moe: FusedMoEConfig,
quant_config: Optional[FusedMoEQuantConfig],
) -> Optional[FusedMoEPrepareAndFinalize]:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
# TODO: could allow this now
assert not moe.use_flashinfer_cutlass_kernels, \
"Must be created in modelopt.py"
if moe.use_pplx_kernels:
assert quant_config is not None
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens,
moe.hidden_dim,
moe.in_dtype,
moe.quant_dtype,
per_act_token_quant=moe.per_act_token_quant,
block_shape=moe.block_shape,
quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
all_to_all_args = dict(
......@@ -165,6 +171,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
)
elif moe.use_deepep_ll_kernels:
assert quant_config is not None
all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
......@@ -174,13 +181,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager.world_size)
handle = all2all_manager.get_handle(all_to_all_args)
# Note : We may want to use FP8 dispatch even otherwise just to
# reduce datamovement
use_fp8_dispatch = (moe.quant_config is not None
and moe.quant_config.quant_dtype
== current_platform.fp8_dtype()
and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE)
# Note: We may want to use FP8 dispatch just to reduce
# data movement.
use_fp8_dispatch = (
quant_config.quant_dtype == current_platform.fp8_dtype()
and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE)
prepare_finalize = DeepEPLLPrepareAndFinalize(
handle,
......@@ -192,11 +197,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
return prepare_finalize
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[FusedMoEPrepareAndFinalize]:
if moe.moe_parallel_config.use_all2all_kernels:
return FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
self) -> Optional[FusedMoEPrepareAndFinalize]:
if self.moe.moe_parallel_config.use_all2all_kernels:
return FusedMoEMethodBase._maybe_make_prepare_finalize(
self.moe, self.moe_quant_config)
else:
return None
......@@ -204,7 +208,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# prepare_communication_buffer_for_model.
def init_prepare_finalize(self, layer: torch.nn.Module):
assert self.moe is not None
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
# We must get the quant config here so that the layer is
# completely initialized, i.e. all weights loaded and post
# processed.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
prepare_finalize = self.maybe_make_prepare_finalize()
if prepare_finalize is not None:
logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__,
......@@ -213,7 +223,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
assert self.fused_experts is None, \
f"Attempt to override experts for {id(self)}!"
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize, self.moe, layer)
experts = self.select_gemm_impl(prepare_finalize, layer)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
......@@ -223,7 +233,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
......@@ -232,6 +241,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
f"{self.__class__.__name__} must select appropriate gemm "
"implementation based on the prepare_finalize")
@abstractmethod
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Build the FusedMoEQuantConfig for `layer`; subclasses must override.

    init_prepare_finalize calls this and stores the result in
    `self.moe_quant_config`. It is invoked there, rather than at
    construction time, so that the layer is completely initialized (all
    weights loaded and post-processed).

    Args:
        layer: The MoE layer whose quantization config is requested.

    Returns:
        The quantization config for the layer, or None.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
@abstractmethod
def apply(
self,
......@@ -265,7 +279,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.has_bias = self.moe.has_bias
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
if self.rocm_aiter_moe_enabled:
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
......@@ -273,23 +286,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
self.rocm_aiter_fused_experts = None # type: ignore
def maybe_make_prepare_finalize(
        self) -> Optional[FusedMoEPrepareAndFinalize]:
    """Return the prepare/finalize object for this method, if any.

    When ROCm AITER MoE is enabled no prepare/finalize object is created
    and None is returned; otherwise the decision is delegated to the base
    class implementation.
    """
    if self.rocm_aiter_moe_enabled:
        return None
    return super().maybe_make_prepare_finalize()
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
# TODO(bnell): Remove. Every layer should have an moe config object.
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
logger.debug("BatchedTritonExperts %s", self.moe)
return BatchedTritonExperts(
max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
else:
logger.debug("TritonExperts %s", self.moe)
return TritonExperts()
return TritonExperts(self.moe_quant_config)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
......@@ -303,7 +323,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
if self.has_bias:
if self.moe.has_bias:
w13_bias = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
......@@ -320,7 +340,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
if self.has_bias:
if self.moe.has_bias:
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
hidden_size,
dtype=params_dtype),
......@@ -442,6 +462,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_replica_count=logical_replica_count,
)
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
    """Return the fused-MoE quant config for ``layer``.

    Args:
        layer: Module whose ``w13_bias`` and ``w2_bias`` attributes are read
            when ``self.moe.has_bias`` is truthy.

    Returns:
        ``FUSED_MOE_UNQUANTIZED_CONFIG`` when ``self.moe.has_bias`` is
        falsy; otherwise the result of ``biased_moe_quant_config`` called
        with the layer's two bias attributes.
    """
    if not self.moe.has_bias:
        return FUSED_MOE_UNQUANTIZED_CONFIG
    w13_bias = layer.w13_bias
    w2_bias = layer.w2_bias
    return biased_moe_quant_config(w13_bias, w2_bias)
def forward_cuda(
self,
layer: torch.nn.Module,
......@@ -486,6 +516,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_replica_count=logical_replica_count)
if self.rocm_aiter_moe_enabled:
assert self.fused_experts is None
return self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
......@@ -496,7 +527,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
elif self.fused_experts is not None:
if self.has_bias:
if self.moe.has_bias:
raise ValueError(
"FusedMoEModularKernel does not support bias.")
return self.fused_experts(
......@@ -517,12 +548,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_bias=layer.w13_bias if self.has_bias else None,
w2_bias=layer.w2_bias if self.has_bias else None,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
......@@ -933,16 +963,18 @@ class FusedMoE(CustomOp):
# since model_config is not set in the pytest test.
model_dtype = params_dtype
moe = FusedMoEConfig.make(num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config=quant_config,
has_bias=has_bias)
moe = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias,
)
self.moe_config = moe
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts
......@@ -990,6 +1022,9 @@ class FusedMoE(CustomOp):
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
# TODO(bnell): flashinfer uses non-batched format.
# Does it really need a batched buffer?
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.use_flashinfer_cutlass_kernels):
......@@ -1062,7 +1097,9 @@ class FusedMoE(CustomOp):
@property
def use_flashinfer_cutlass_kernels(self):
    """Whether this layer uses the FlashInfer CUTLASS kernels.

    True-ish only when all of the following hold:
      * ``self.moe_quant_config`` has been set (is not ``None``),
      * its ``quant_dtype`` equals ``"nvfp4"``,
      * ``self.moe_config.use_flashinfer_cutlass_kernels`` is truthy.
    """
    # One return only: an earlier unconditional
    # `return self.moe_config.use_flashinfer_cutlass_kernels` made the
    # quant-config gate below unreachable.
    return (self.moe_quant_config is not None
            and self.moe_quant_config.quant_dtype == "nvfp4"
            and self.moe_config.use_flashinfer_cutlass_kernels)
def update_expert_map(self):
# ep_size and ep_rank should already be updated
......@@ -1492,6 +1529,11 @@ class FusedMoE(CustomOp):
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
self.logical_replica_count = logical_replica_count[moe_layer_idx]
def ensure_moe_quant_config(self):
    """Populate ``self.quant_method.moe_quant_config`` if it is still None.

    When the attribute is ``None`` it is assigned the result of
    ``self.quant_method.get_fused_moe_quant_config(self)``; an already-set
    value is left untouched.
    """
    method = self.quant_method
    if method.moe_quant_config is not None:
        return
    method.moe_quant_config = method.get_fused_moe_quant_config(self)
@staticmethod
def select_experts(
hidden_states: torch.Tensor,
......@@ -1711,6 +1753,8 @@ class FusedMoE(CustomOp):
assert (
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
self.ensure_moe_quant_config()
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
if self.shared_experts is not None:
full_shared_final_hidden_states = torch.empty_like(
......@@ -1825,14 +1869,17 @@ class FusedMoE(CustomOp):
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.quant_method is not None
self.ensure_moe_quant_config()
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
use_flashinfer_cutlass_kernels = (
self.dp_size > 1
and self.moe_config.use_flashinfer_cutlass_kernels)
_use_flashinfer_cutlass_kernels = (self.dp_size > 1 and
self.use_flashinfer_cutlass_kernels)
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or use_flashinfer_cutlass_kernels):
or _use_flashinfer_cutlass_kernels):
return self.forward_impl_chunked(hidden_states, router_logits)
do_naive_dispatch_combine: bool = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment