Unverified Commit 5963b98b authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent e6585ddb
......@@ -9,6 +9,8 @@ import torch
from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
......@@ -143,10 +145,16 @@ def pplx_cutlass_moe(
device="cuda",
dtype=torch.int64)
experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
out_dtype, per_act_token, per_out_ch,
ab_strides1, ab_strides2, c_strides1,
c_strides2)
experts = CutlassBatchedExpertsFp8(
num_local_experts, num_dispatchers, out_dtype, ab_strides1,
ab_strides2, c_strides1, c_strides2,
fp8_w8a8_moe_quant_config(
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank]))
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,
......@@ -167,10 +175,7 @@ def pplx_cutlass_moe(
chunk_topk_ids,
global_num_experts=num_experts,
expert_map=None, #TODO
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank])
)
torch.cuda.synchronize()
......
......@@ -58,7 +58,7 @@ BATCHED_MOE_MNK_FACTORS = [
]
PPLX_COMBOS = [
# TODO: figure out why this fails, seems to be test problem
# TODO(bnell): figure out why this fails, seems to be test problem
#(1, 128, 128),
(2, 128, 512),
(3, 1024, 2048),
......@@ -360,18 +360,18 @@ def pplx_prepare_finalize(
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
a_chunk,
a1_scale,
a2_scale,
chunk_topk_weight,
chunk_topk_ids,
num_experts,
None,
False,
FusedMoEQuantConfig(
FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant,
False,
block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False,
block_shape=block_shape,
a1_scale=a1_scale,
a2_scale=a2_scale,
),
)
......@@ -540,20 +540,6 @@ def pplx_moe(
topk_ids = topk_ids.to(dtype=torch.uint32)
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk = chunk_by_rank(a, rank, world_size)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
......@@ -567,6 +553,28 @@ def pplx_moe(
a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
)
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=quant_config,
)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
......@@ -585,10 +593,6 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts)
if use_cudagraphs:
......@@ -605,10 +609,6 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts)
torch.cuda.synchronize()
......@@ -820,7 +820,7 @@ def test_pplx_moe_slow(
k,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
......@@ -897,7 +897,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
k,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
args["w1"] = w1
args["w2"] = w2
......
......@@ -7,10 +7,12 @@ import itertools
import pytest
import torch
from tests.kernels.moe.utils import fused_moe
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (9, 0):
......@@ -152,11 +154,12 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
score,
topk,
renormalize=False,
use_fp8_w8a8=True, # using fp8
per_channel_quant=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=None, # Not using block quantization
quant_config=fp8_w8a8_moe_quant_config(
per_act_token_quant=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=None, # Not using block quantization
),
)
# Check results
......
......@@ -9,7 +9,8 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
......@@ -34,18 +35,22 @@ def triton_moe(
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_channel_quant=per_act_token_quant,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
block_shape=block_shape)
quant_config=quant_config)
def batched_moe(
......@@ -64,6 +69,16 @@ def batched_moe(
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
......@@ -72,21 +87,11 @@ def batched_moe(
BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
quant_config=quant_config,
),
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
def naive_batched_moe(
......@@ -105,6 +110,16 @@ def naive_batched_moe(
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
......@@ -113,21 +128,11 @@ def naive_batched_moe(
NaiveBatchedExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
quant_config=quant_config,
),
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
def chunk_scales(scales: Optional[torch.Tensor], start: int,
......@@ -216,7 +221,7 @@ def make_test_weight(
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
......@@ -228,7 +233,7 @@ def make_test_weight(
w_gs_l = [None] * e
for idx in range(e):
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape)
w = torch.stack(w_l)
w_s = torch.stack(w_s_l)
......@@ -258,16 +263,16 @@ def make_test_weights(
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]]:
return (
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
per_out_ch_quant),
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
per_out_ch_quant),
)
......@@ -285,6 +290,76 @@ def per_token_cast_to_fp8(
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def make_test_quant_config(
e: int,
n: int,
k: int,
in_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None] = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
(_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
e,
n,
k,
in_dtype,
quant_dtype,
per_out_ch_quant=per_act_token_quant,
block_shape=block_shape,
)
# Hacky/trivial scales for nvfp4.
a1_gscale: Optional[torch.Tensor] = None
a2_gscale: Optional[torch.Tensor] = None
if quant_dtype == "nvfp4":
a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
a1_scale = a1_gscale
a2_scale = a2_gscale
else:
a1_scale = None
a2_scale = None
return w1, w2, FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_s,
w2_scale=w2_s,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
a1_scale=a1_scale,
a2_scale=a2_scale,
# TODO: make sure this is handled properly
g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
)
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
renormalize: bool = False,
quant_config: Optional[FusedMoEQuantConfig] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(hidden_states, score.float(), topk,
renormalize)
return fused_experts(hidden_states,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=quant_config)
# CustomOp?
class BaselineMM(torch.nn.Module):
......
......@@ -8,7 +8,8 @@ import pytest
import torch
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8)
from vllm.platforms import current_platform
......@@ -42,7 +43,8 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
return C.reshape(origin_C_shape).to(output_dtype)
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight,
topk_ids):
"""This function performs fused moe with per-column int8 quantization
using native torch."""
......@@ -57,8 +59,6 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
# Calculate routing
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
# Process each expert
......@@ -127,20 +127,27 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
score = torch.randn((M, E), dtype=dtype)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(score, topk)
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk,
topk_weights, topk_ids)
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
out = fused_moe(
quant_config = FusedMoEQuantConfig.make(
torch.int8,
per_act_token_quant=True,
block_shape=None,
w1_scale=w1_s,
w2_scale=w2_s,
)
out = fused_experts(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_int8_w8a8=True, # Using int8-w8a8
per_channel_quant=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=None, # Not using block quantization
topk_weights,
topk_ids,
quant_config=quant_config,
)
# Check results
......
......@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None
......@@ -36,6 +37,7 @@ __all__ = [
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize",
"activation_without_mul",
"override_config",
"get_config",
]
......@@ -43,7 +45,6 @@ __all__ = [
if HAS_TRITON:
# import to register the custom ops
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
......@@ -56,13 +57,12 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts, fused_experts, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
TritonExperts, fused_experts, fused_topk, get_config_file_name,
grouped_topk)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
__all__ += [
"fused_moe",
"fused_topk",
"fused_experts",
"get_config_file_name",
......
......@@ -8,6 +8,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
......@@ -212,27 +214,20 @@ def silu_mul_fp8_quant_deep_gemm_cuda(
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# The Deep Gemm kernels only support block size of 128
DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
def __init__(self,
max_num_tokens: int,
num_dispatchers: int,
block_shape: list[int],
per_act_token_quant=False):
def __init__(
self,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
):
"""
max_num_tokens: Maximum number of tokens from a DP Rank
num_dispatchers: The number of DP dispatchers.
block_shape: Block quantization block shape.
per_act_token_quant: Per activation token quantization flag.
quant_config: Quantization configuration
"""
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE
super().__init__(quant_config)
assert self.block_shape == deep_gemm_block_shape()
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
......@@ -290,12 +285,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -321,11 +311,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m = max_num_tokens
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, self.w1_scale),
workspace1, expert_num_tokens, expected_m)
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda(
workspace1, expert_num_tokens)
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
expert_num_tokens, expected_m)
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, self.w2_scale),
output, expert_num_tokens, expected_m)
......@@ -8,55 +8,37 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self,
max_num_tokens: int,
num_dispatchers: int,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
allow_deep_gemm: bool = False):
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
))
def __init__(
self,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
allow_deep_gemm: bool = False,
):
super().__init__(quant_config)
self.batched_triton_experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape,
quant_config=self.quant_config,
)
self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8
and self.block_shape
== BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE)
self.allow_deep_gemm = (allow_deep_gemm
and self.quant_config.use_fp8_w8a8 and
self.block_shape == deep_gemm_block_shape())
self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
block_shape=self.block_shape, # type: ignore[arg-type]
quant_config=self.quant_config,
) if self.allow_deep_gemm else None
assert (self.batched_deep_gemm_experts is not None
......@@ -143,12 +125,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -158,7 +135,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if self.allow_deep_gemm else self.batched_triton_experts)
assert experts is not None
experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
activation, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_tokens_meta,
activation, global_num_experts, expert_map, a1q_scale,
workspace13, workspace2, expert_tokens_meta,
apply_router_weight_on_input)
......@@ -211,21 +211,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
))
assert quant_config.use_fp8_w8a8
super().__init__(quant_config)
self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1
self.ab_strides2 = ab_strides2
......@@ -247,19 +240,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None
if expert_tokens_meta is not None:
......@@ -273,9 +261,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
in_dtype = hidden_states.dtype
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
self.c_strides2, workspace13, workspace2, expert_num_tokens,
global_num_experts, expert_map, self.w1_scale, self.w2_scale,
a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2,
self.c_strides1, self.c_strides2, workspace13, workspace2,
expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
use_batched_format, topk_weights)
......@@ -286,23 +275,19 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape,
quant_config,
)
@property
......@@ -348,23 +333,19 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
max_experts_per_worker: int,
num_dispatchers: int,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape,
quant_config,
)
assert max_experts_per_worker > 0
self.max_experts_per_worker = max_experts_per_worker
......@@ -414,16 +395,12 @@ def cutlass_moe_fp8(
w2_q: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None,
quant_config: FusedMoEQuantConfig,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
......@@ -475,10 +452,18 @@ def cutlass_moe_fp8(
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
if per_act_token is None:
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
per_out_ch = w1_scale.numel() != w1_q.size(0)
assert quant_config is not None
if quant_config.a1_scale is not None:
assert (quant_config.per_act_token_quant ==
quant_config.a1_scale.numel() != 1)
if quant_config.a2_scale is not None:
assert (quant_config.per_act_token_quant ==
quant_config.a2_scale.numel() != 1)
assert (quant_config.w1_scale is None
or (quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1)
== w1_q.size(1))))
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
0)
......@@ -487,12 +472,11 @@ def cutlass_moe_fp8(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
quant_config=quant_config,
),
)
......@@ -502,14 +486,9 @@ def cutlass_moe_fp8(
w2_q,
topk_weights,
topk_ids,
False,
activation,
num_experts,
expert_map,
w1_scale,
w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
activation=activation,
global_num_experts=num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -542,7 +521,7 @@ def run_cutlass_moe_fp4(
) -> None:
"""
MoE implementation for FP4 Inputs
# Gemm 1
a: Input tensor: [m, k] (half/bfloat16)
a1_gscale: Activation scale per expert: [e] (float32)
......@@ -552,16 +531,16 @@ def run_cutlass_moe_fp4(
full precision)
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
(Block size = 16 for NVFP4)
# Gemm 2
a2_gscale: Activation scale per expert: [e]
w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
topk_weights: [m, topk] dtype: float8
topk_ids: [m, topk] dtype: float8
m, n, k: Unquantized weight shapes, dtype: int
e: number of experts, dtype: int
......@@ -652,42 +631,21 @@ def run_cutlass_moe_fp4(
return
# Split into batched and non-batched
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
max_experts_per_worker: int,
out_dtype: torch.dtype,
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
use_batched_format: bool = False,
):
super().__init__(
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
FusedMoEQuantConfig(
quant_dtype=None, # skip quantization in prepare/finalize
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
))
super().__init__(quant_config)
self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
# TODO(bnell): put this stuff into quant config?
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
@property
def activation_formats(
self
......@@ -746,12 +704,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: torch.Tensor,
a1q_scale: Optional[torch.Tensor], # unused
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -765,11 +718,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
a=hidden_states,
a1_gscale=self.a1_gscale,
w1_fp4=w1,
w1_blockscale=w1_scale,
w1_blockscale=self.w1_scale,
w1_alphas=self.g1_alphas,
a2_gscale=self.a2_gscale,
w2_fp4=w2,
w2_blockscale=w2_scale,
w2_blockscale=self.w2_scale,
w2_alphas=self.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
......@@ -788,14 +741,9 @@ def cutlass_moe_fp4(
a: torch.Tensor,
w1_fp4: torch.Tensor,
w2_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w2_blockscale: torch.Tensor,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
m: int,
n: int,
k: int,
......@@ -805,17 +753,31 @@ def cutlass_moe_fp4(
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")
# TODO(bnell): this feels a bit hacky
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
quant_config = FusedMoEQuantConfig.make(
quant_dtype=None, # skip quantization in prepare/finalize
per_act_token_quant=quant_config.per_act_token_quant,
per_out_ch_quant=quant_config.per_out_ch_quant,
block_shape=quant_config.block_shape,
g1_alphas=quant_config.g1_alphas,
g2_alphas=quant_config.g2_alphas,
a1_gscale=quant_config.a1_gscale,
a2_gscale=quant_config.a2_gscale,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
g1_alphas,
g2_alphas,
a1_gscale,
a2_gscale,
max_experts_per_worker=e,
out_dtype=a.dtype,
per_act_token_quant=False,
per_out_ch_quant=False,
quant_config=quant_config,
use_batched_format=False,
),
)
......@@ -830,10 +792,6 @@ def cutlass_moe_fp4(
activation="silu",
global_num_experts=e,
expert_map=None,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
a1_scale=None,
a2_scale=None,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -891,6 +849,7 @@ def _valid_cutlass_block_scaled_grouped_gemm(
return True
# TODO(bnell): would be nice combine/integrate with regular cutlass_fp8.
def run_cutlass_block_scaled_fused_experts(
a: torch.Tensor,
w1: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import Optional
import torch
......@@ -9,9 +8,11 @@ from tqdm import tqdm
import vllm.envs as env
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce)
compute_aligned_M, deep_gemm_block_shape, deepgemm_moe_permute,
deepgemm_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
......@@ -25,14 +26,6 @@ from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
logger = init_logger(__name__)
@functools.cache
def deep_gemm_block_shape() -> list[int]:
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
block = dg.get_m_alignment_for_contiguous_layout()
return [block, block]
def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
align = deep_gemm_block_shape()[0]
return align <= M and N % align == 0 and K % align == 0
......@@ -163,13 +156,12 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor,
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=deep_gemm_block_shape(),
))
def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__(quant_config)
assert quant_config.block_shape == deep_gemm_block_shape()
assert quant_config.quant_dtype == torch.float8_e4m3fn
assert not quant_config.per_act_token_quant
assert not quant_config.per_out_ch_quant
@property
def activation_formats(
......@@ -221,21 +213,17 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
assert self.block_shape is not None
assert a1q_scale is not None
assert w1_scale is not None
assert w2_scale is not None
assert self.a2_scale is None
assert self.block_shape is not None
assert self.w1_scale is not None
assert self.w2_scale is not None
a1q = hidden_states
_, N, K = w1.size()
......@@ -270,7 +258,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
aq_out=a1q_perm)
assert a1q.size(0) == M_sum
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, self.w1_scale),
mm1_out, expert_ids)
self.activation(activation, act_out, mm1_out.view(-1, N))
......@@ -281,7 +269,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
column_major_scales=True,
out_q=quant_out)
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, self.w2_scale),
mm2_out, expert_ids)
if apply_router_weight_on_input:
......@@ -348,9 +336,16 @@ def deep_gemm_moe_fp8(
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=deep_gemm_block_shape())
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
DeepGemmExperts(),
DeepGemmExperts(quant_config),
)
return fn(
hidden_states,
......@@ -358,13 +353,9 @@ def deep_gemm_moe_fp8(
w2,
topk_weights,
topk_ids,
inplace,
activation,
global_num_experts,
expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -183,8 +183,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -204,7 +202,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# Quant and Dispatch
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
a1_scale,
quant_config.a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
......@@ -215,7 +213,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else:
a1q = a1
a1q_scale = None
a1_post_scale = a1_scale
a1_post_scale = quant_config.a1_scale
return (lambda *args: None,
self._do_dispatch(tokens=a1q,
......@@ -229,8 +227,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -238,9 +234,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
(_, receiver) = self.prepare_async(a1, a1_scale, a2_scale,
topk_weights, topk_ids, num_experts,
expert_map,
(_, receiver) = self.prepare_async(a1, topk_weights, topk_ids,
num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
return receiver()
......
......@@ -77,15 +77,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def _do_quant(
self,
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
a1_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
block_k = block_shape[1] if block_shape is not None else None
if self.use_fp8_dispatch:
block_k = quant_config.block_shape[
1] if quant_config.block_shape is not None else None
if block_k == DEEPEP_QUANT_BLOCK_SIZE:
# DeepEP kernels did the quantization for us.
x, x_scales = x
......@@ -101,12 +99,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype,
per_act_token_quant,
block_shape)
x, x_scales = moe_kernel_quantize_input(
x, quant_config.a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
x = x.view((num_experts, -1, hidden_dim))
if quant_dtype is not None:
if quant_config.quant_dtype is not None:
assert x_scales is not None
x_scales = normalize_batched_scales_shape(x_scales, num_experts)
......@@ -118,8 +116,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -139,9 +135,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
assert hidden_size % 128 == 0, \
"DeepEP kernels quantize the inputs in blocks of shape 128"
has_per_token_scales = a1_scale.numel(
) != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
has_per_token_scales = quant_config.a1_scale.numel(
) != 1 if quant_config.a1_scale is not None else (
quant_config.a2_scale.numel() != 1
if quant_config.a2_scale is not None else False)
assert not has_per_token_scales, (
"low_latency kernels doesn't support dispatching per-token scales")
......@@ -163,20 +160,21 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return_recv_hook=True)
self.handles[a2a_idx] = handle
return (hook, lambda: self._receiver(expert_x, expert_num_tokens,
a1_scale, a1.dtype, quant_config))
return (
hook,
lambda: self._receiver(expert_x, expert_num_tokens, quant_config.
a1_scale, a1.dtype, quant_config))
def _receiver(
self,
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
expert_num_tokens: torch.Tensor,
a1_scale,
a1_dtype,
a1_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype,
quant_config)
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
......@@ -186,8 +184,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -195,8 +191,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(a1, a1_scale, a2_scale,
topk_weights, topk_ids,
hook, receiver = self.prepare_async(a1, topk_weights, topk_ids,
num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import Optional
import torch
......@@ -44,33 +44,20 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
out_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None],
quant_config: FusedMoEQuantConfig,
ep_rank: int = 0,
ep_size: int = 1,
tp_rank: int = 0,
tp_size: int = 1,
):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=quant_dtype,
per_act_token_quant=False,
block_shape=None,
))
assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
super().__init__(quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
"Only nvfp4,fp8 quantization are currently supported.")
self.ep_rank = ep_rank
self.ep_size = ep_size
self.tp_rank = tp_rank
self.tp_size = tp_size
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
self.out_dtype = out_dtype
@property
......@@ -141,12 +128,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], # Not used
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -162,17 +144,17 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
fc2_expert_weights = w2
else:
# Ensure w1_scale and w2_scale are not None before calling view
assert w1_scale is not None and w2_scale is not None, (
assert self.w1_scale is not None and self.w2_scale is not None, (
"w1_scale and w2_scale must not "
"be None for FlashInferExperts")
# Flashinfer CUTLASS kernel takes scalar global scales,
# min because inv_scale.
quant_scales = [
self.a1_gscale,
w1_scale.view(torch.int32),
self.w1_scale.view(torch.int32),
self.g1_alphas,
self.a2_gscale,
w2_scale.view(torch.int32),
self.w2_scale.view(torch.int32),
self.g2_alphas,
]
# FlashInfer API requires weight to be long for nvfp4
......@@ -202,12 +184,7 @@ def flashinfer_cutlass_moe_fp4(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
quant_config: FusedMoEQuantConfig,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
......@@ -216,15 +193,10 @@ def flashinfer_cutlass_moe_fp4(
) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel(
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False,
a1_gscale=a1_gscale),
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False),
FlashInferExperts(
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
out_dtype=hidden_states.dtype,
quant_dtype="nvfp4",
quant_config=quant_config,
))
return fused_experts(
......@@ -237,7 +209,5 @@ def flashinfer_cutlass_moe_fp4(
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -22,13 +22,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def __init__(
self,
use_dp: bool,
a1_gscale: Optional[torch.Tensor],
num_dispatchers: int = 1,
):
super().__init__()
self.num_dispatchers_ = num_dispatchers
self.use_dp = use_dp
self.a1_gscale = a1_gscale
self.local_tokens = None
@property
......@@ -47,14 +45,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], # Not used
a2_scale: Optional[torch.Tensor], # Not used
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
# TODO(bnell): use quant_config + scales instead of ctor args
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
......@@ -67,7 +62,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
self.a1_gscale,
quant_config.a1_gscale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List # noqa: UP035
from typing import Optional
import torch
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import direct_register_custom_op
def flashinfer_fused_moe_blockscale_fp8(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
x: torch.Tensor,
w13_weight: torch.Tensor,
w13_weight_scale_inv: torch.Tensor,
w2_weight: torch.Tensor,
w2_weight_scale_inv: torch.Tensor,
global_num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: List[int], #noqa: UP006
routed_scaling: float = 1.0) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
assert top_k <= global_num_experts
assert top_k <= 8
assert topk_group <= 4
assert global_num_experts > num_expert_group
assert global_num_experts % num_expert_group == 0
assert global_num_experts % 4 == 0
assert top_k < (topk_group * global_num_experts / num_expert_group)
assert block_shape == [128, 128]
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()
return flashinfer_trtllm_fp8_block_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=a_q,
hidden_states_scale=a_sf_t,
gemm1_weights=w13_weight,
gemm1_weights_scale=w13_weight_scale_inv,
gemm2_weights=w2_weight,
gemm2_weights_scale=w2_weight_scale_inv,
num_experts=global_num_experts,
top_k=top_k,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling,
tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
global_num_experts),
routing_method_type=2, # DeepSeek-styled routing method
use_shuffled_weight=False,
)
def flashinfer_fused_moe_blockscale_fp8_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
x: torch.Tensor,
w13_weight: torch.Tensor,
w13_weight_scale_inv: torch.Tensor,
w2_weight: torch.Tensor,
w2_weight_scale_inv: torch.Tensor,
global_num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routed_scaling: float = 1.0) -> torch.Tensor:
return torch.empty_like(x)
# TODO(bnell): Does this really need to be a torch.op?
direct_register_custom_op(
op_name="flashinfer_fused_moe_blockscale_fp8",
op_func=flashinfer_fused_moe_blockscale_fp8,
mutates_args=[],
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
def flashinfer_fused_moe_per_tensor_scale_fp8(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: Optional[int],
topk_group: Optional[int],
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
use_routing_scales_on_input: bool,
routing_method_type: int,
routed_scaling_factor: float = 1.0) -> torch.Tensor:
num_expert_group = num_expert_group if num_expert_group is not None else 0
topk_group = topk_group if topk_group is not None else 0
quant_hidden_states, _ = moe_kernel_quantize_input(
hidden_states,
input_scale,
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False)
from vllm.utils.flashinfer import (
flashinfer_trtllm_fp8_per_tensor_scale_moe)
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=quant_hidden_states,
gemm1_weights=gemm1_weights,
output1_scales_scalar=output1_scales_scalar,
output1_scales_gate_scalar=output1_scales_gate_scalar,
gemm2_weights=gemm2_weights,
output2_scales_scalar=output2_scales_scalar,
num_experts=num_experts,
top_k=top_k,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input,
tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
top_k, num_experts),
routing_method_type=routing_method_type)
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: Optional[int],
topk_group: Optional[int],
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
use_routing_scales_on_input: bool,
routing_method_type: int,
routed_scaling_factor: float = 1.0) -> torch.Tensor:
return torch.empty_like(hidden_states)
# TODO(bnell): Does this really need to be a torch.op?
direct_register_custom_op(
op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
mutates_args=["hidden_states"],
fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
......@@ -8,7 +8,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config)
try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched)
from vllm.model_executor.layers.fused_moe.utils import (
......@@ -498,8 +498,6 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
......@@ -545,14 +543,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dtype=torch.float32,
device=a1.device)
else:
assert a1_scale is None
assert quant_config.a1_scale is None
b_a1_scale = None
first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts
a1_scale = normalize_scales_shape(a1_scale)
a2_scale = normalize_scales_shape(a2_scale)
a1_scale = normalize_scales_shape(quant_config.a1_scale)
for expert_id in range(first_expert, last_expert):
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
......@@ -623,28 +620,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
self,
max_num_tokens: int,
num_dispatchers: int,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
assert not use_mxfp4_w4a4, "NYI"
super().__init__(quant_config)
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert not self.quant_config.use_mxfp4_w4a4, "NYI"
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
......@@ -705,12 +687,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -740,10 +717,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
tmp = _resize_cache(workspace2, (num, N))
if self.quant_config.is_quantized:
assert a1q_scale is not None and w1_scale is not None
assert a1q_scale is not None and self.w1_scale is not None
input = self.dequant(hidden_states[expert, :, :],
a1q_scale[expert])
w1_dq = self.dequant(w1[expert], w1_scale[expert])
w1_dq = self.dequant(w1[expert], self.w1_scale[expert])
input = input[:num] @ w1_dq.transpose(0, 1)
else:
input = hidden_states[expert, :num, :] @ w1[expert].transpose(
......@@ -752,8 +729,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.activation(activation, tmp, input.to(tmp.dtype))
if self.quant_config.is_quantized:
assert w2_scale is not None
w2_dq = self.dequant(w2[expert], w2_scale[expert])
assert self.w2_scale is not None
w2_dq = self.dequant(w2[expert], self.w2_scale[expert])
else:
w2_dq = w2[expert]
......@@ -840,35 +817,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self,
max_num_tokens: int,
num_dispatchers: int,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
assert not use_mxfp4_w4a4, "NYI"
super().__init__(quant_config)
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert not self.quant_config.use_mxfp4_w4a4, "NYI"
assert max_num_tokens > 0
assert num_dispatchers > 0
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int8_w8a8 = use_int8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a16 = use_int8_w8a16
self.use_mxfp4_w4a4 = use_mxfp4_w4a4
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
......@@ -921,19 +878,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
# Check constraints.
if self.use_int4_w4a16:
if self.quant_config.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
"Hidden size mismatch")
else:
......@@ -958,11 +910,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert w1.size(0) == E
assert w2.size(0) == E
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
dtype=hidden_states.dtype)
config_dtype = self.quant_config.config_name(hidden_states.dtype)
config = try_get_optimal_moe_config(
w1.size(),
......@@ -992,7 +940,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache2 = _resize_cache(workspace2,
(E, max_num_tokens, N // 2))
if self.use_fp8_w8a8:
# TODO(bnell): should this be done for any quantized type?
if self.quant_config.use_fp8_w8a8:
intermediate_cache1.fill_(0)
a1q_scale = normalize_batched_scales_shape(a1q_scale, E)
......@@ -1005,11 +954,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens=expert_num_tokens,
compute_type=compute_type,
A_scale=a1q_scale,
B_scale=w1_scale,
B_zp=w1_zp,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
B_scale=self.w1_scale,
B_zp=self.w1_zp,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
config=config,
per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape)
......@@ -1021,7 +970,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache1.view(-1, N))
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
intermediate_cache2, a2_scale, max_num_tokens, E, N,
intermediate_cache2, self.a2_scale, max_num_tokens, E, N,
expert_num_tokens, self.quant_dtype, self.per_act_token_quant,
self.block_shape)
......@@ -1032,11 +981,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens=expert_num_tokens,
compute_type=compute_type,
A_scale=a2q_scale,
B_scale=w2_scale,
B_zp=w2_zp,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
B_scale=self.w2_scale,
B_zp=self.w2_zp,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
config=config,
per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.utils import has_triton_kernels
......@@ -23,9 +25,6 @@ if has_triton_kernels():
"Failed to import Triton kernels. Please make sure your triton "
"version is compatible.")
if TYPE_CHECKING:
from triton_kernels.matmul_ogs import PrecisionConfig
def triton_kernel_moe_forward(
hidden_states: torch.Tensor,
......@@ -35,20 +34,10 @@ def triton_kernel_moe_forward(
topk: int,
renormalize: bool,
activation: str = "silu",
quant_config: Optional[FusedMoEQuantConfig] = None,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_precision: Optional["PrecisionConfig"] = None,
w2_precision: Optional["PrecisionConfig"] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
routing_data, gather_idx, scatter_idx = routing(gating_output,
......@@ -64,20 +53,10 @@ def triton_kernel_moe_forward(
gather_idx,
scatter_idx,
activation=activation,
quant_config=quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_precision=w1_precision,
w2_precision=w2_precision,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
expert_map=expert_map)
# This is a triton implementation of the fused_experts function
......@@ -90,28 +69,23 @@ def triton_kernel_fused_experts(
gather_indx, # GatherIndx
scatter_indx, # ScatterIndx
activation: str = "silu",
quant_config: Optional[FusedMoEQuantConfig] = None,
swiglu_alpha: float = 1.702,
swiglu_limit: float = 7.0,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_precision: Optional["PrecisionConfig"] = None,
w2_precision: Optional["PrecisionConfig"] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
a1q_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
# type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16
assert w1_bias is None or w1_bias.dtype == torch.float32
assert w2_bias is None or w2_bias.dtype == torch.float32
assert (quant_config.w1_bias is None
or quant_config.w1_bias.dtype == torch.float32)
assert (quant_config.w2_bias is None
or quant_config.w2_bias.dtype == torch.float32)
# Shape check, only check non-mxfp4
assert hidden_states.shape[-1] == w1.shape[-2]
......@@ -130,20 +104,20 @@ def triton_kernel_fused_experts(
intermediate_cache1 = matmul_ogs(
hidden_states,
w1,
w1_bias,
quant_config.w1_bias,
routing_data,
gather_indx=gather_indx,
precision_config=w1_precision,
precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act)
intermediate_cache3 = matmul_ogs(
intermediate_cache1,
w2,
w2_bias,
quant_config.w2_bias,
routing_data,
scatter_indx=scatter_indx,
precision_config=w2_precision,
precision_config=quant_config.w2_precision,
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)
......@@ -154,21 +128,13 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
quant_config,
max_num_tokens: int,
num_dispatchers: int,
w1_precision: "PrecisionConfig",
w2_precision: "PrecisionConfig",
w1_bias: Optional[torch.Tensor],
w2_bias: Optional[torch.Tensor],
quant_config: FusedMoEQuantConfig,
):
super().__init__(quant_config)
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
self.w1_precision = w1_precision
self.w2_precision = w2_precision
self.w1_bias = w1_bias
self.w2_bias = w2_bias
@property
def activation_formats(
......@@ -212,12 +178,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
......@@ -228,20 +189,12 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states,
w1,
w2,
None,
None,
None,
routing_data=None,
gather_indx=None,
scatter_indx=None,
activation=activation,
quant_config=self.quant_config,
apply_router_weight_on_input=False,
use_fp8_w8a8=False,
per_channel_quant=False,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=self.w1_bias,
w2_bias=self.w2_bias,
w1_precision=self.w1_precision,
w2_precision=self.w2_precision,
a1_scale=a1q_scale,
a2_scale=a2_scale)
a1q_scale=a1q_scale)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment