Unverified Commit 5963b98b authored by bnellnm, committed by GitHub
Browse files

[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent e6585ddb
...@@ -9,6 +9,8 @@ import torch ...@@ -9,6 +9,8 @@ import torch
from tests.kernels.utils import torch_experts from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8) CutlassBatchedExpertsFp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
...@@ -143,10 +145,16 @@ def pplx_cutlass_moe( ...@@ -143,10 +145,16 @@ def pplx_cutlass_moe(
device="cuda", device="cuda",
dtype=torch.int64) dtype=torch.int64)
experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers, experts = CutlassBatchedExpertsFp8(
out_dtype, per_act_token, per_out_ch, num_local_experts, num_dispatchers, out_dtype, ab_strides1,
ab_strides1, ab_strides2, c_strides1, ab_strides2, c_strides1, c_strides2,
c_strides2) fp8_w8a8_moe_quant_config(
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank]))
fused_cutlass_experts = FusedMoEModularKernel( fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
...@@ -167,10 +175,7 @@ def pplx_cutlass_moe( ...@@ -167,10 +175,7 @@ def pplx_cutlass_moe(
chunk_topk_ids, chunk_topk_ids,
global_num_experts=num_experts, global_num_experts=num_experts,
expert_map=None, #TODO expert_map=None, #TODO
w1_scale=chunk_by_rank(w1_scale, rank, world_size), )
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank])
torch.cuda.synchronize() torch.cuda.synchronize()
......
...@@ -58,7 +58,7 @@ BATCHED_MOE_MNK_FACTORS = [ ...@@ -58,7 +58,7 @@ BATCHED_MOE_MNK_FACTORS = [
] ]
PPLX_COMBOS = [ PPLX_COMBOS = [
# TODO: figure out why this fails, seems to be test problem # TODO(bnell): figure out why this fails, seems to be test problem
#(1, 128, 128), #(1, 128, 128),
(2, 128, 512), (2, 128, 512),
(3, 1024, 2048), (3, 1024, 2048),
...@@ -360,18 +360,18 @@ def pplx_prepare_finalize( ...@@ -360,18 +360,18 @@ def pplx_prepare_finalize(
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare( b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
a_chunk, a_chunk,
a1_scale,
a2_scale,
chunk_topk_weight, chunk_topk_weight,
chunk_topk_ids, chunk_topk_ids,
num_experts, num_experts,
None, None,
False, False,
FusedMoEQuantConfig( FusedMoEQuantConfig.make(
quant_dtype, quant_dtype,
per_act_token_quant, per_act_token_quant=per_act_token_quant,
False, per_out_ch_quant=False,
block_shape, block_shape=block_shape,
a1_scale=a1_scale,
a2_scale=a2_scale,
), ),
) )
...@@ -540,20 +540,6 @@ def pplx_moe( ...@@ -540,20 +540,6 @@ def pplx_moe(
topk_ids = topk_ids.to(dtype=torch.uint32) topk_ids = topk_ids.to(dtype=torch.uint32)
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)
# Note: workers with the same dp_rank must use the exact same inputs. # Note: workers with the same dp_rank must use the exact same inputs.
a_chunk = chunk_by_rank(a, rank, world_size) a_chunk = chunk_by_rank(a, rank, world_size)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size) chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
...@@ -567,6 +553,28 @@ def pplx_moe( ...@@ -567,6 +553,28 @@ def pplx_moe(
a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size) a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size) a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
)
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=quant_config,
)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)
# Note: for now use_compile will error out if the problem size is # Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and # large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later. # setup code in case we are able to revisit this later.
...@@ -585,10 +593,6 @@ def pplx_moe( ...@@ -585,10 +593,6 @@ def pplx_moe(
w2_chunk, w2_chunk,
chunk_topk_weight, chunk_topk_weight,
chunk_topk_ids, chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts) global_num_experts=num_experts)
if use_cudagraphs: if use_cudagraphs:
...@@ -605,10 +609,6 @@ def pplx_moe( ...@@ -605,10 +609,6 @@ def pplx_moe(
w2_chunk, w2_chunk,
chunk_topk_weight, chunk_topk_weight,
chunk_topk_ids, chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts) global_num_experts=num_experts)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -820,7 +820,7 @@ def test_pplx_moe_slow( ...@@ -820,7 +820,7 @@ def test_pplx_moe_slow(
k, k,
quant_dtype=quant_dtype, quant_dtype=quant_dtype,
block_shape=block_shape, block_shape=block_shape,
per_act_token_quant=per_act_token_quant, per_out_ch_quant=per_act_token_quant,
) )
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e, parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
...@@ -897,7 +897,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, ...@@ -897,7 +897,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
k, k,
quant_dtype=quant_dtype, quant_dtype=quant_dtype,
block_shape=block_shape, block_shape=block_shape,
per_act_token_quant=per_act_token_quant, per_out_ch_quant=per_act_token_quant,
) )
args["w1"] = w1 args["w1"] = w1
args["w2"] = w2 args["w2"] = w2
......
...@@ -7,10 +7,12 @@ import itertools ...@@ -7,10 +7,12 @@ import itertools
import pytest import pytest
import torch import torch
from tests.kernels.moe.utils import fused_moe
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.get_device_capability() < (9, 0): if current_platform.get_device_capability() < (9, 0):
...@@ -152,11 +154,12 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): ...@@ -152,11 +154,12 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
score, score,
topk, topk,
renormalize=False, renormalize=False,
use_fp8_w8a8=True, # using fp8 quant_config=fp8_w8a8_moe_quant_config(
per_channel_quant=True, per_act_token_quant=True,
w1_scale=w1_s, w1_scale=w1_s,
w2_scale=w2_s, w2_scale=w2_s,
block_shape=None, # Not using block quantization block_shape=None, # Not using block quantization
),
) )
# Check results # Check results
......
...@@ -9,7 +9,8 @@ from tests.kernels.quant_utils import per_block_cast_to_int8 ...@@ -9,7 +9,8 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX) FLOAT8_E4M3_MAX)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
...@@ -34,18 +35,22 @@ def triton_moe( ...@@ -34,18 +35,22 @@ def triton_moe(
per_act_token_quant=False, per_act_token_quant=False,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
return fused_experts(a, return fused_experts(a,
w1, w1,
w2, w2,
topk_weight, topk_weight,
topk_ids, topk_ids,
w1_scale=w1_scale, quant_config=quant_config)
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_channel_quant=per_act_token_quant,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
block_shape=block_shape)
def batched_moe( def batched_moe(
...@@ -64,6 +69,16 @@ def batched_moe( ...@@ -64,6 +69,16 @@ def batched_moe(
) -> torch.Tensor: ) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64) max_num_tokens = round_up(a.shape[0], 64)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
fused_experts = FusedMoEModularKernel( fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens, BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1, num_dispatchers=1,
...@@ -72,21 +87,11 @@ def batched_moe( ...@@ -72,21 +87,11 @@ def batched_moe(
BatchedTritonExperts( BatchedTritonExperts(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
num_dispatchers=1, num_dispatchers=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, quant_config=quant_config,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
), ),
) )
return fused_experts(a, return fused_experts(a, w1, w2, topk_weight, topk_ids)
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
def naive_batched_moe( def naive_batched_moe(
...@@ -105,6 +110,16 @@ def naive_batched_moe( ...@@ -105,6 +110,16 @@ def naive_batched_moe(
) -> torch.Tensor: ) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64) max_num_tokens = round_up(a.shape[0], 64)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
fused_experts = FusedMoEModularKernel( fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens, BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1, num_dispatchers=1,
...@@ -113,21 +128,11 @@ def naive_batched_moe( ...@@ -113,21 +128,11 @@ def naive_batched_moe(
NaiveBatchedExperts( NaiveBatchedExperts(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
num_dispatchers=1, num_dispatchers=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, quant_config=quant_config,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
), ),
) )
return fused_experts(a, return fused_experts(a, w1, w2, topk_weight, topk_ids)
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
def chunk_scales(scales: Optional[torch.Tensor], start: int, def chunk_scales(scales: Optional[torch.Tensor], start: int,
...@@ -216,7 +221,7 @@ def make_test_weight( ...@@ -216,7 +221,7 @@ def make_test_weight(
in_dtype: torch.dtype = torch.bfloat16, in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None, quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False, per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]: Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15 w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
...@@ -228,7 +233,7 @@ def make_test_weight( ...@@ -228,7 +233,7 @@ def make_test_weight(
w_gs_l = [None] * e w_gs_l = [None] * e
for idx in range(e): for idx in range(e):
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights( w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape) w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape)
w = torch.stack(w_l) w = torch.stack(w_l)
w_s = torch.stack(w_s_l) w_s = torch.stack(w_s_l)
...@@ -258,16 +263,16 @@ def make_test_weights( ...@@ -258,16 +263,16 @@ def make_test_weights(
in_dtype: torch.dtype = torch.bfloat16, in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None, quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False, per_out_ch_quant: bool = False,
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], ) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]], Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]]: Optional[torch.Tensor]]]:
return ( return (
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_act_token_quant), per_out_ch_quant),
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_act_token_quant), per_out_ch_quant),
) )
...@@ -285,6 +290,76 @@ def per_token_cast_to_fp8( ...@@ -285,6 +290,76 @@ def per_token_cast_to_fp8(
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def make_test_quant_config(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: Union[torch.dtype, str, None] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
    """Create test weights together with a matching ``FusedMoEQuantConfig``.

    The weights come from ``make_test_weights``; ``per_act_token_quant`` is
    forwarded to it as ``per_out_ch_quant``. The per-weight scales it returns
    are passed to ``FusedMoEQuantConfig.make`` along with the activation
    scales chosen below.

    Returns:
        ``(w1, w2, quant_config)`` where ``w1``/``w2`` are the second entry of
        each tuple produced by ``make_test_weights`` (presumably the quantized
        weights -- confirm against ``make_test_weights``).
    """
    w1_parts, w2_parts = make_test_weights(
        e,
        n,
        k,
        in_dtype,
        quant_dtype,
        per_out_ch_quant=per_act_token_quant,
        block_shape=block_shape,
    )
    _, w1, w1_s, w1_gs = w1_parts
    _, w2, w2_s, w2_gs = w2_parts

    # Hacky/trivial scales for nvfp4: all-ones global activation scales, one
    # entry per expert. Every other quant dtype gets no activation scales.
    a1_gscale: Optional[torch.Tensor] = None
    a2_gscale: Optional[torch.Tensor] = None
    if quant_dtype == "nvfp4":
        a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
        a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)

    # The activation scales are the very same objects as the global scales
    # (both are None when quant_dtype is not nvfp4).
    a1_scale = a1_gscale
    a2_scale = a2_gscale

    # TODO: make sure this is handled properly
    g1_alphas = None if w1_gs is None else (1 / w1_gs)
    g2_alphas = None if w2_gs is None else (1 / w2_gs)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
    )
    return w1, w2, quant_config
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    renormalize: bool = False,
    quant_config: Optional[FusedMoEQuantConfig] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Test helper: run ``fused_topk`` routing, then ``fused_experts``.

    ``score`` is converted with ``.float()`` before being handed to
    ``fused_topk``; the third value ``fused_topk`` returns is discarded.
    ``quant_config``, ``global_num_experts`` and ``expert_map`` are forwarded
    unchanged to ``fused_experts``, whose result is returned.
    """
    routing_weights, routing_ids, _ = fused_topk(
        hidden_states,
        score.float(),
        topk,
        renormalize,
    )
    output = fused_experts(
        hidden_states,
        w1,
        w2,
        routing_weights,
        routing_ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        quant_config=quant_config,
    )
    return output
# CustomOp? # CustomOp?
class BaselineMM(torch.nn.Module): class BaselineMM(torch.nn.Module):
......
...@@ -8,7 +8,8 @@ import pytest ...@@ -8,7 +8,8 @@ import pytest
import torch import torch
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import ( from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8) per_token_quant_int8)
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -42,7 +43,8 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): ...@@ -42,7 +43,8 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
return C.reshape(origin_C_shape).to(output_dtype) return C.reshape(origin_C_shape).to(output_dtype)
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight,
topk_ids):
"""This function performs fused moe with per-column int8 quantization """This function performs fused moe with per-column int8 quantization
using native torch.""" using native torch."""
...@@ -57,8 +59,6 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): ...@@ -57,8 +59,6 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
# Calculate routing # Calculate routing
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1) topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1) topk_ids = topk_ids.view(-1)
# Process each expert # Process each expert
...@@ -127,20 +127,27 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): ...@@ -127,20 +127,27 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(score, topk)
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk,
topk_weights, topk_ids)
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) quant_config = FusedMoEQuantConfig.make(
out = fused_moe( torch.int8,
per_act_token_quant=True,
block_shape=None,
w1_scale=w1_s,
w2_scale=w2_s,
)
out = fused_experts(
a, a,
w1, w1,
w2, w2,
score, topk_weights,
topk, topk_ids,
renormalize=False, quant_config=quant_config,
use_int8_w8a8=True, # Using int8-w8a8
per_channel_quant=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=None, # Not using block quantization
) )
# Check results # Check results
......
...@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import ( ...@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize) FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None _config: Optional[dict[str, Any]] = None
...@@ -36,6 +37,7 @@ __all__ = [ ...@@ -36,6 +37,7 @@ __all__ = [
"FusedMoEPermuteExpertsUnpermute", "FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat", "FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize", "FusedMoEPrepareAndFinalize",
"activation_without_mul",
"override_config", "override_config",
"get_config", "get_config",
] ]
...@@ -43,7 +45,6 @@ __all__ = [ ...@@ -43,7 +45,6 @@ __all__ = [
if HAS_TRITON: if HAS_TRITON:
# import to register the custom ops # import to register the custom ops
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts) BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
...@@ -56,13 +57,12 @@ if HAS_TRITON: ...@@ -56,13 +57,12 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts) BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts, fused_experts, fused_moe, fused_topk, TritonExperts, fused_experts, fused_topk, get_config_file_name,
get_config_file_name, grouped_topk) grouped_topk)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts) TritonOrDeepGemmExperts)
__all__ += [ __all__ += [
"fused_moe",
"fused_topk", "fused_topk",
"fused_experts", "fused_experts",
"get_config_file_name", "get_config_file_name",
......
...@@ -8,6 +8,8 @@ import torch ...@@ -8,6 +8,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate) TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
...@@ -212,27 +214,20 @@ def silu_mul_fp8_quant_deep_gemm_cuda( ...@@ -212,27 +214,20 @@ def silu_mul_fp8_quant_deep_gemm_cuda(
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# The Deep Gemm kernels only support block size of 128
DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
def __init__(self, def __init__(
self,
max_num_tokens: int, max_num_tokens: int,
num_dispatchers: int, num_dispatchers: int,
block_shape: list[int], quant_config: FusedMoEQuantConfig,
per_act_token_quant=False): ):
""" """
max_num_tokens: Maximum number of tokens from a DP Rank max_num_tokens: Maximum number of tokens from a DP Rank
num_dispatchers: The number of DP dispatchers. num_dispatchers: The number of DP dispatchers.
block_shape: Block quantization block shape. quant_config: Quantization configuration
per_act_token_quant: Per activation token quantization flag.
""" """
super().__init__( super().__init__(quant_config)
FusedMoEQuantConfig( assert self.block_shape == deep_gemm_block_shape()
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers self.num_dispatchers = num_dispatchers
...@@ -290,12 +285,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -290,12 +285,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
...@@ -321,11 +311,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -321,11 +311,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# for the M expectation of each batch, correctly setting this value # for the M expectation of each batch, correctly setting this value
# may lead to better performance. # may lead to better performance.
expected_m = max_num_tokens expected_m = max_num_tokens
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale), fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, self.w1_scale),
workspace1, expert_num_tokens, expected_m) workspace1, expert_num_tokens, expected_m)
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda( a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda(
workspace1, expert_num_tokens) workspace1, expert_num_tokens)
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output, fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, self.w2_scale),
expert_num_tokens, expected_m) output, expert_num_tokens, expected_m)
...@@ -8,55 +8,37 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -8,55 +8,37 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts) BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts) BatchedTritonExperts)
class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, def __init__(
self,
max_num_tokens: int, max_num_tokens: int,
num_dispatchers: int, num_dispatchers: int,
use_fp8_w8a8: bool = False, quant_config: FusedMoEQuantConfig,
use_int8_w8a8: bool = False, allow_deep_gemm: bool = False,
use_int8_w8a16: bool = False, ):
use_int4_w4a16: bool = False, super().__init__(quant_config)
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
allow_deep_gemm: bool = False):
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
))
self.batched_triton_experts = BatchedTritonExperts( self.batched_triton_experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers, num_dispatchers=num_dispatchers,
use_fp8_w8a8=use_fp8_w8a8, quant_config=self.quant_config,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape,
) )
self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 self.allow_deep_gemm = (allow_deep_gemm
and self.block_shape and self.quant_config.use_fp8_w8a8 and
== BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE) self.block_shape == deep_gemm_block_shape())
self.batched_deep_gemm_experts = BatchedDeepGemmExperts( self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers, num_dispatchers=num_dispatchers,
block_shape=self.block_shape, # type: ignore[arg-type] quant_config=self.quant_config,
) if self.allow_deep_gemm else None ) if self.allow_deep_gemm else None
assert (self.batched_deep_gemm_experts is not None assert (self.batched_deep_gemm_experts is not None
...@@ -143,12 +125,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -143,12 +125,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
...@@ -158,7 +135,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -158,7 +135,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if self.allow_deep_gemm else self.batched_triton_experts) if self.allow_deep_gemm else self.batched_triton_experts)
assert experts is not None assert experts is not None
experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids, experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
activation, global_num_experts, expert_map, w1_scale, activation, global_num_experts, expert_map, a1q_scale,
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, workspace13, workspace2, expert_tokens_meta,
workspace2, expert_tokens_meta,
apply_router_weight_on_input) apply_router_weight_on_input)
...@@ -211,21 +211,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -211,21 +211,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
out_dtype: Optional[torch.dtype], out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor, ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor, ab_strides2: torch.Tensor,
c_strides1: torch.Tensor, c_strides1: torch.Tensor,
c_strides2: torch.Tensor, c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None, quant_config: FusedMoEQuantConfig,
): ):
super().__init__( assert quant_config.use_fp8_w8a8
FusedMoEQuantConfig( super().__init__(quant_config)
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
))
self.out_dtype = out_dtype self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1 self.ab_strides1 = ab_strides1
self.ab_strides2 = ab_strides2 self.ab_strides2 = ab_strides2
...@@ -247,19 +240,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -247,19 +240,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
): ):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None expert_num_tokens = None
if expert_tokens_meta is not None: if expert_tokens_meta is not None:
...@@ -273,9 +261,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -273,9 +261,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
in_dtype = hidden_states.dtype in_dtype = hidden_states.dtype
run_cutlass_moe_fp8( run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable, output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, global_num_experts, expert_map, self.w1_scale, self.w2_scale,
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1, a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2,
self.c_strides2, workspace13, workspace2, expert_num_tokens, self.c_strides1, self.c_strides2, workspace13, workspace2,
expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype, self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant, self.per_act_token_quant, self.per_out_ch_quant,
use_batched_format, topk_weights) use_batched_format, topk_weights)
...@@ -286,23 +275,19 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): ...@@ -286,23 +275,19 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
def __init__( def __init__(
self, self,
out_dtype: Optional[torch.dtype], out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor, ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor, ab_strides2: torch.Tensor,
c_strides1: torch.Tensor, c_strides1: torch.Tensor,
c_strides2: torch.Tensor, c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None, quant_config: FusedMoEQuantConfig,
): ):
super().__init__( super().__init__(
out_dtype, out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1, ab_strides1,
ab_strides2, ab_strides2,
c_strides1, c_strides1,
c_strides2, c_strides2,
block_shape, quant_config,
) )
@property @property
...@@ -348,23 +333,19 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): ...@@ -348,23 +333,19 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
max_experts_per_worker: int, max_experts_per_worker: int,
num_dispatchers: int, num_dispatchers: int,
out_dtype: Optional[torch.dtype], out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor, ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor, ab_strides2: torch.Tensor,
c_strides1: torch.Tensor, c_strides1: torch.Tensor,
c_strides2: torch.Tensor, c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None, quant_config: FusedMoEQuantConfig,
): ):
super().__init__( super().__init__(
out_dtype, out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1, ab_strides1,
ab_strides2, ab_strides2,
c_strides1, c_strides1,
c_strides2, c_strides2,
block_shape, quant_config,
) )
assert max_experts_per_worker > 0 assert max_experts_per_worker > 0
self.max_experts_per_worker = max_experts_per_worker self.max_experts_per_worker = max_experts_per_worker
...@@ -414,16 +395,12 @@ def cutlass_moe_fp8( ...@@ -414,16 +395,12 @@ def cutlass_moe_fp8(
w2_q: torch.Tensor, w2_q: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor, ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor, ab_strides2: torch.Tensor,
c_strides1: torch.Tensor, c_strides1: torch.Tensor,
c_strides2: torch.Tensor, c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None, quant_config: FusedMoEQuantConfig,
activation: str = "silu", activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -475,10 +452,18 @@ def cutlass_moe_fp8( ...@@ -475,10 +452,18 @@ def cutlass_moe_fp8(
Returns: Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer. - torch.Tensor: The fp16 output tensor after applying the MoE layer.
""" """
if per_act_token is None: assert quant_config is not None
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False) if quant_config.a1_scale is not None:
per_out_ch = w1_scale.numel() != w1_q.size(0) assert (quant_config.per_act_token_quant ==
quant_config.a1_scale.numel() != 1)
if quant_config.a2_scale is not None:
assert (quant_config.per_act_token_quant ==
quant_config.a2_scale.numel() != 1)
assert (quant_config.w1_scale is None
or (quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1)
== w1_q.size(1))))
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size( num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
0) 0)
...@@ -487,12 +472,11 @@ def cutlass_moe_fp8( ...@@ -487,12 +472,11 @@ def cutlass_moe_fp8(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8( CutlassExpertsFp8(
out_dtype=a.dtype, out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
ab_strides1=ab_strides1, ab_strides1=ab_strides1,
ab_strides2=ab_strides2, ab_strides2=ab_strides2,
c_strides1=c_strides1, c_strides1=c_strides1,
c_strides2=c_strides2, c_strides2=c_strides2,
quant_config=quant_config,
), ),
) )
...@@ -502,14 +486,9 @@ def cutlass_moe_fp8( ...@@ -502,14 +486,9 @@ def cutlass_moe_fp8(
w2_q, w2_q,
topk_weights, topk_weights,
topk_ids, topk_ids,
False, activation=activation,
activation, global_num_experts=num_experts,
num_experts, expert_map=expert_map,
expert_map,
w1_scale,
w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
...@@ -652,42 +631,21 @@ def run_cutlass_moe_fp4( ...@@ -652,42 +631,21 @@ def run_cutlass_moe_fp4(
return return
# Split into batched and non-batched
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
max_experts_per_worker: int, max_experts_per_worker: int,
out_dtype: torch.dtype, out_dtype: torch.dtype,
per_act_token_quant: bool, quant_config: FusedMoEQuantConfig,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
use_batched_format: bool = False, use_batched_format: bool = False,
): ):
super().__init__( super().__init__(quant_config)
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
FusedMoEQuantConfig(
quant_dtype=None, # skip quantization in prepare/finalize
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
))
self.max_experts_per_worker = max_experts_per_worker self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype self.out_dtype = out_dtype
self.use_batched_format = use_batched_format self.use_batched_format = use_batched_format
# TODO(bnell): put this stuff into quant config?
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
@property @property
def activation_formats( def activation_formats(
self self
...@@ -746,12 +704,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -746,12 +704,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: torch.Tensor, a1q_scale: Optional[torch.Tensor], # unused
w2_scale: torch.Tensor,
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: torch.Tensor,
workspace13: Optional[torch.Tensor], workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor], workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
...@@ -765,11 +718,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -765,11 +718,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
a=hidden_states, a=hidden_states,
a1_gscale=self.a1_gscale, a1_gscale=self.a1_gscale,
w1_fp4=w1, w1_fp4=w1,
w1_blockscale=w1_scale, w1_blockscale=self.w1_scale,
w1_alphas=self.g1_alphas, w1_alphas=self.g1_alphas,
a2_gscale=self.a2_gscale, a2_gscale=self.a2_gscale,
w2_fp4=w2, w2_fp4=w2,
w2_blockscale=w2_scale, w2_blockscale=self.w2_scale,
w2_alphas=self.g2_alphas, w2_alphas=self.g2_alphas,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
...@@ -788,14 +741,9 @@ def cutlass_moe_fp4( ...@@ -788,14 +741,9 @@ def cutlass_moe_fp4(
a: torch.Tensor, a: torch.Tensor,
w1_fp4: torch.Tensor, w1_fp4: torch.Tensor,
w2_fp4: torch.Tensor, w2_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w2_blockscale: torch.Tensor,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
m: int, m: int,
n: int, n: int,
k: int, k: int,
...@@ -805,17 +753,31 @@ def cutlass_moe_fp4( ...@@ -805,17 +753,31 @@ def cutlass_moe_fp4(
assert expert_map is None, ("Expert Parallelism / expert_map " assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for " "is currently not supported for "
"ModelOptNvFp4FusedMoE's cutlass_moe_fp4.") "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")
# TODO(bnell): this feels a bit hacky
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
quant_config = FusedMoEQuantConfig.make(
quant_dtype=None, # skip quantization in prepare/finalize
per_act_token_quant=quant_config.per_act_token_quant,
per_out_ch_quant=quant_config.per_out_ch_quant,
block_shape=quant_config.block_shape,
g1_alphas=quant_config.g1_alphas,
g2_alphas=quant_config.g2_alphas,
a1_gscale=quant_config.a1_gscale,
a2_gscale=quant_config.a2_gscale,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
)
fn = mk.FusedMoEModularKernel( fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4( CutlassExpertsFp4(
g1_alphas,
g2_alphas,
a1_gscale,
a2_gscale,
max_experts_per_worker=e, max_experts_per_worker=e,
out_dtype=a.dtype, out_dtype=a.dtype,
per_act_token_quant=False, quant_config=quant_config,
per_out_ch_quant=False,
use_batched_format=False, use_batched_format=False,
), ),
) )
...@@ -830,10 +792,6 @@ def cutlass_moe_fp4( ...@@ -830,10 +792,6 @@ def cutlass_moe_fp4(
activation="silu", activation="silu",
global_num_experts=e, global_num_experts=e,
expert_map=None, expert_map=None,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
a1_scale=None,
a2_scale=None,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
...@@ -891,6 +849,7 @@ def _valid_cutlass_block_scaled_grouped_gemm( ...@@ -891,6 +849,7 @@ def _valid_cutlass_block_scaled_grouped_gemm(
return True return True
# TODO(bnell): would be nice to combine/integrate with regular cutlass_fp8.
def run_cutlass_block_scaled_fused_experts( def run_cutlass_block_scaled_fused_experts(
a: torch.Tensor, a: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import Optional from typing import Optional
import torch import torch
...@@ -9,9 +8,11 @@ from tqdm import tqdm ...@@ -9,9 +8,11 @@ from tqdm import tqdm
import vllm.envs as env import vllm.envs as env
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce) compute_aligned_M, deep_gemm_block_shape, deepgemm_moe_permute,
deepgemm_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
...@@ -25,14 +26,6 @@ from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous ...@@ -25,14 +26,6 @@ from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
logger = init_logger(__name__) logger = init_logger(__name__)
@functools.cache
def deep_gemm_block_shape() -> list[int]:
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
block = dg.get_m_alignment_for_contiguous_layout()
return [block, block]
def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool: def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
align = deep_gemm_block_shape()[0] align = deep_gemm_block_shape()[0]
return align <= M and N % align == 0 and K % align == 0 return align <= M and N % align == 0 and K % align == 0
...@@ -163,13 +156,12 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor, ...@@ -163,13 +156,12 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor,
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self): def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__( super().__init__(quant_config)
FusedMoEQuantConfig( assert quant_config.block_shape == deep_gemm_block_shape()
quant_dtype=torch.float8_e4m3fn, assert quant_config.quant_dtype == torch.float8_e4m3fn
per_act_token_quant=False, assert not quant_config.per_act_token_quant
block_shape=deep_gemm_block_shape(), assert not quant_config.per_out_ch_quant
))
@property @property
def activation_formats( def activation_formats(
...@@ -221,21 +213,17 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -221,21 +213,17 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
): ):
assert self.block_shape is not None
assert a1q_scale is not None assert a1q_scale is not None
assert w1_scale is not None assert self.a2_scale is None
assert w2_scale is not None assert self.block_shape is not None
assert self.w1_scale is not None
assert self.w2_scale is not None
a1q = hidden_states a1q = hidden_states
_, N, K = w1.size() _, N, K = w1.size()
...@@ -270,7 +258,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -270,7 +258,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
aq_out=a1q_perm) aq_out=a1q_perm)
assert a1q.size(0) == M_sum assert a1q.size(0) == M_sum
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale), m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, self.w1_scale),
mm1_out, expert_ids) mm1_out, expert_ids)
self.activation(activation, act_out, mm1_out.view(-1, N)) self.activation(activation, act_out, mm1_out.view(-1, N))
...@@ -281,7 +269,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -281,7 +269,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
column_major_scales=True, column_major_scales=True,
out_q=quant_out) out_q=quant_out)
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale), m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, self.w2_scale),
mm2_out, expert_ids) mm2_out, expert_ids)
if apply_router_weight_on_input: if apply_router_weight_on_input:
...@@ -348,9 +336,16 @@ def deep_gemm_moe_fp8( ...@@ -348,9 +336,16 @@ def deep_gemm_moe_fp8(
Returns: Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer. - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
""" """
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=deep_gemm_block_shape())
fn = mk.FusedMoEModularKernel( fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
DeepGemmExperts(), DeepGemmExperts(quant_config),
) )
return fn( return fn(
hidden_states, hidden_states,
...@@ -358,13 +353,9 @@ def deep_gemm_moe_fp8( ...@@ -358,13 +353,9 @@ def deep_gemm_moe_fp8(
w2, w2,
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace, inplace=inplace,
activation, activation=activation,
global_num_experts, global_num_experts=global_num_experts,
expert_map, expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
...@@ -183,8 +183,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -183,8 +183,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async( def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -204,7 +202,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -204,7 +202,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# Quant and Dispatch # Quant and Dispatch
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
a1_scale, quant_config.a1_scale,
quant_dtype=quant_config.quant_dtype, quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant, per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape, block_shape=quant_config.block_shape,
...@@ -215,7 +213,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -215,7 +213,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else: else:
a1q = a1 a1q = a1
a1q_scale = None a1q_scale = None
a1_post_scale = a1_scale a1_post_scale = quant_config.a1_scale
return (lambda *args: None, return (lambda *args: None,
self._do_dispatch(tokens=a1q, self._do_dispatch(tokens=a1q,
...@@ -229,8 +227,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -229,8 +227,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -238,9 +234,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -238,9 +234,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
(_, receiver) = self.prepare_async(a1, a1_scale, a2_scale, (_, receiver) = self.prepare_async(a1, topk_weights, topk_ids,
topk_weights, topk_ids, num_experts, num_experts, expert_map,
expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
quant_config) quant_config)
return receiver() return receiver()
......
...@@ -77,15 +77,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -77,15 +77,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def _do_quant( def _do_quant(
self, self,
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
a1_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype, a1_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None], quant_config: FusedMoEQuantConfig,
per_act_token_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
block_k = block_shape[1] if block_shape is not None else None
if self.use_fp8_dispatch: if self.use_fp8_dispatch:
block_k = quant_config.block_shape[
1] if quant_config.block_shape is not None else None
if block_k == DEEPEP_QUANT_BLOCK_SIZE: if block_k == DEEPEP_QUANT_BLOCK_SIZE:
# DeepEP kernels did the quantization for us. # DeepEP kernels did the quantization for us.
x, x_scales = x x, x_scales = x
...@@ -101,12 +99,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -101,12 +99,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# TODO (varun): Optimization - Use a batched version of quant # TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim)) x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype, x, x_scales = moe_kernel_quantize_input(
per_act_token_quant, x, quant_config.a1_scale, quant_config.quant_dtype,
block_shape) quant_config.per_act_token_quant, quant_config.block_shape)
x = x.view((num_experts, -1, hidden_dim)) x = x.view((num_experts, -1, hidden_dim))
if quant_dtype is not None: if quant_config.quant_dtype is not None:
assert x_scales is not None assert x_scales is not None
x_scales = normalize_batched_scales_shape(x_scales, num_experts) x_scales = normalize_batched_scales_shape(x_scales, num_experts)
...@@ -118,8 +116,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -118,8 +116,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare_async( def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -139,9 +135,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -139,9 +135,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
assert hidden_size % 128 == 0, \ assert hidden_size % 128 == 0, \
"DeepEP kernels quantize the inputs in blocks of shape 128" "DeepEP kernels quantize the inputs in blocks of shape 128"
has_per_token_scales = a1_scale.numel( has_per_token_scales = quant_config.a1_scale.numel(
) != 1 if a1_scale is not None else ( ) != 1 if quant_config.a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False) quant_config.a2_scale.numel() != 1
if quant_config.a2_scale is not None else False)
assert not has_per_token_scales, ( assert not has_per_token_scales, (
"low_latency kernels doesn't support dispatching per-token scales") "low_latency kernels doesn't support dispatching per-token scales")
...@@ -163,20 +160,21 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -163,20 +160,21 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return_recv_hook=True) return_recv_hook=True)
self.handles[a2a_idx] = handle self.handles[a2a_idx] = handle
return (hook, lambda: self._receiver(expert_x, expert_num_tokens, return (
hook,
lambda: self._receiver(expert_x, expert_num_tokens, quant_config.
a1_scale, a1.dtype, quant_config)) a1_scale, a1.dtype, quant_config))
def _receiver( def _receiver(
self, self,
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
expert_num_tokens: torch.Tensor, expert_num_tokens: torch.Tensor,
a1_scale, a1_scale: Optional[torch.Tensor],
a1_dtype, a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
expert_x, expert_x_scale = self._do_quant( expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype,
expert_x, a1_scale, a1_dtype, quant_config.quant_dtype, quant_config)
quant_config.per_act_token_quant, quant_config.block_shape)
expert_tokens_meta = mk.ExpertTokensMetadata( expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
...@@ -186,8 +184,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -186,8 +184,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -195,8 +191,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -195,8 +191,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(a1, a1_scale, a2_scale, hook, receiver = self.prepare_async(a1, topk_weights, topk_ids,
topk_weights, topk_ids,
num_experts, expert_map, num_experts, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
quant_config) quant_config)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union from typing import Optional
import torch import torch
...@@ -44,33 +44,20 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -44,33 +44,20 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
out_dtype: torch.dtype, out_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None], quant_config: FusedMoEQuantConfig,
ep_rank: int = 0, ep_rank: int = 0,
ep_size: int = 1, ep_size: int = 1,
tp_rank: int = 0, tp_rank: int = 0,
tp_size: int = 1, tp_size: int = 1,
): ):
super().__init__( super().__init__(quant_config)
FusedMoEQuantConfig( assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
quant_dtype=quant_dtype,
per_act_token_quant=False,
block_shape=None,
))
assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
"Only nvfp4,fp8 quantization are currently supported.") "Only nvfp4,fp8 quantization are currently supported.")
self.ep_rank = ep_rank self.ep_rank = ep_rank
self.ep_size = ep_size self.ep_size = ep_size
self.tp_rank = tp_rank self.tp_rank = tp_rank
self.tp_size = tp_size self.tp_size = tp_size
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
self.out_dtype = out_dtype self.out_dtype = out_dtype
@property @property
...@@ -141,12 +128,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -141,12 +128,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], # Not used
workspace13: Optional[torch.Tensor], workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor], workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
...@@ -162,17 +144,17 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -162,17 +144,17 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
fc2_expert_weights = w2 fc2_expert_weights = w2
else: else:
# Ensure w1_scale and w2_scale are not None before calling view # Ensure w1_scale and w2_scale are not None before calling view
assert w1_scale is not None and w2_scale is not None, ( assert self.w1_scale is not None and self.w2_scale is not None, (
"w1_scale and w2_scale must not " "w1_scale and w2_scale must not "
"be None for FlashInferExperts") "be None for FlashInferExperts")
# Flashinfer CUTLASS kernel takes scalar global scales, # Flashinfer CUTLASS kernel takes scalar global scales,
# min because inv_scale. # min because inv_scale.
quant_scales = [ quant_scales = [
self.a1_gscale, self.a1_gscale,
w1_scale.view(torch.int32), self.w1_scale.view(torch.int32),
self.g1_alphas, self.g1_alphas,
self.a2_gscale, self.a2_gscale,
w2_scale.view(torch.int32), self.w2_scale.view(torch.int32),
self.g2_alphas, self.g2_alphas,
] ]
# FlashInfer API requires weight to be long for nvfp4 # FlashInfer API requires weight to be long for nvfp4
...@@ -202,12 +184,7 @@ def flashinfer_cutlass_moe_fp4( ...@@ -202,12 +184,7 @@ def flashinfer_cutlass_moe_fp4(
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
w1_scale: torch.Tensor, quant_config: FusedMoEQuantConfig,
w2_scale: torch.Tensor,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
inplace: bool = False, inplace: bool = False,
activation: str = "silu", activation: str = "silu",
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -216,15 +193,10 @@ def flashinfer_cutlass_moe_fp4( ...@@ -216,15 +193,10 @@ def flashinfer_cutlass_moe_fp4(
) -> torch.Tensor: ) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel( fused_experts = mk.FusedMoEModularKernel(
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False, FlashInferCutlassMoEPrepareAndFinalize(use_dp=False),
a1_gscale=a1_gscale),
FlashInferExperts( FlashInferExperts(
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
out_dtype=hidden_states.dtype, out_dtype=hidden_states.dtype,
quant_dtype="nvfp4", quant_config=quant_config,
)) ))
return fused_experts( return fused_experts(
...@@ -237,7 +209,5 @@ def flashinfer_cutlass_moe_fp4( ...@@ -237,7 +209,5 @@ def flashinfer_cutlass_moe_fp4(
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
...@@ -22,13 +22,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -22,13 +22,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def __init__( def __init__(
self, self,
use_dp: bool, use_dp: bool,
a1_gscale: Optional[torch.Tensor],
num_dispatchers: int = 1, num_dispatchers: int = 1,
): ):
super().__init__() super().__init__()
self.num_dispatchers_ = num_dispatchers self.num_dispatchers_ = num_dispatchers
self.use_dp = use_dp self.use_dp = use_dp
self.a1_gscale = a1_gscale
self.local_tokens = None self.local_tokens = None
@property @property
...@@ -47,14 +45,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -47,14 +45,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], # Not used
a2_scale: Optional[torch.Tensor], # Not used
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
# TODO(bnell): use quant_config + scales instead of ctor args
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
...@@ -67,7 +62,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -67,7 +62,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
self.a1_gscale, quant_config.a1_gscale,
quant_config.quant_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.per_act_token_quant,
quant_config.block_shape, quant_config.block_shape,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List # noqa: UP035
from typing import Optional
import torch
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import direct_register_custom_op
def flashinfer_fused_moe_blockscale_fp8(
        routing_logits: torch.Tensor,
        routing_bias: torch.Tensor,
        x: torch.Tensor,
        w13_weight: torch.Tensor,
        w13_weight_scale_inv: torch.Tensor,
        w2_weight: torch.Tensor,
        w2_weight_scale_inv: torch.Tensor,
        global_num_experts: int,
        top_k: int,
        num_expert_group: int,
        topk_group: int,
        intermediate_size: int,
        expert_offset: int,
        local_num_experts: int,
        block_shape: List[int],  #noqa: UP006
        routed_scaling: float = 1.0) -> torch.Tensor:
    """Run the FlashInfer TRT-LLM block-scale FP8 MoE on `x`.

    `x` is quantized with `per_token_group_quant_fp8` using a group size of
    `block_shape[1]`, the resulting scales are transposed and made
    contiguous, and everything is forwarded to
    `flashinfer_trtllm_fp8_block_scale_moe`.

    Args:
        routing_logits: Forwarded as `routing_logits`.
        routing_bias: Forwarded as `routing_bias`.
        x: Activations to quantize; `x.shape[0]` is also used to pick the
            tile size.
        w13_weight: Forwarded as `gemm1_weights`.
        w13_weight_scale_inv: Forwarded as `gemm1_weights_scale`.
        w2_weight: Forwarded as `gemm2_weights`.
        w2_weight_scale_inv: Forwarded as `gemm2_weights_scale`.
        global_num_experts: Forwarded as `num_experts`.
        top_k: Forwarded as `top_k`.
        num_expert_group: Forwarded as `n_group`.
        topk_group: Forwarded as `topk_group`.
        intermediate_size: Forwarded as `intermediate_size`.
        expert_offset: Forwarded as `local_expert_offset`.
        local_num_experts: Forwarded as `local_num_experts`.
        block_shape: Must equal `[128, 128]`.
        routed_scaling: Forwarded as `routed_scaling_factor`.

    Returns:
        Whatever `flashinfer_trtllm_fp8_block_scale_moe` returns.

    Raises:
        AssertionError: If any of the routing/block-shape constraints
            checked below does not hold.
    """
    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

    # Constraints on the routing configuration, checked up front.
    assert top_k <= global_num_experts
    assert top_k <= 8
    assert topk_group <= 4
    assert global_num_experts > num_expert_group
    assert global_num_experts % num_expert_group == 0
    assert global_num_experts % 4 == 0
    routed_expert_bound = (topk_group * global_num_experts /
                           num_expert_group)
    assert top_k < routed_expert_bound
    assert block_shape == [128, 128]

    group_size = block_shape[1]
    x_quant, x_scales = per_token_group_quant_fp8(x, group_size)
    # NOTE: scales of hidden states have to be transposed!
    x_scales_transposed = x_scales.t().contiguous()

    num_tokens = x.shape[0]
    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, top_k,
                                                global_num_experts)

    return flashinfer_trtllm_fp8_block_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=x_quant,
        hidden_states_scale=x_scales_transposed,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale_inv,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale_inv,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling,
        tile_tokens_dim=tile_tokens_dim,
        routing_method_type=2,  # DeepSeek-styled routing method
        use_shuffled_weight=False,
    )
def flashinfer_fused_moe_blockscale_fp8_fake(
        routing_logits: torch.Tensor,
        routing_bias: torch.Tensor,
        x: torch.Tensor,
        w13_weight: torch.Tensor,
        w13_weight_scale_inv: torch.Tensor,
        w2_weight: torch.Tensor,
        w2_weight_scale_inv: torch.Tensor,
        global_num_experts: int,
        top_k: int,
        num_expert_group: int,
        topk_group: int,
        intermediate_size: int,
        expert_offset: int,
        local_num_experts: int,
        block_shape: list[int],
        routed_scaling: float = 1.0) -> torch.Tensor:
    """Fake counterpart of `flashinfer_fused_moe_blockscale_fp8`.

    Performs no computation: every argument except `x` is ignored, and the
    result is a freshly allocated, uninitialized tensor created with
    `torch.empty_like(x)`.

    Returns:
        An uninitialized tensor with the same shape, dtype and device as `x`.
    """
    placeholder_output = torch.empty_like(x)
    return placeholder_output
# TODO(bnell): Does this really need to be a torch.op?
# Register flashinfer_fused_moe_blockscale_fp8 under the op name
# "flashinfer_fused_moe_blockscale_fp8", paired with its *_fake counterpart
# (which only allocates an output shaped like `x`). No argument is declared
# as mutated, and the op is tagged with torch.Tag.needs_fixed_stride_order.
# NOTE(review): the exact registration semantics (namespace, dispatch key)
# live in vllm.utils.direct_register_custom_op -- confirm there.
direct_register_custom_op(
op_name="flashinfer_fused_moe_blockscale_fp8",
op_func=flashinfer_fused_moe_blockscale_fp8,
mutates_args=[],
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
def flashinfer_fused_moe_per_tensor_scale_fp8(
        routing_logits: torch.Tensor,
        routing_bias: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        input_scale: torch.Tensor,
        gemm1_weights: torch.Tensor,
        gemm2_weights: torch.Tensor,
        output1_scales_scalar: torch.Tensor,
        output1_scales_gate_scalar: torch.Tensor,
        output2_scales_scalar: torch.Tensor,
        num_experts: int,
        top_k: int,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
        intermediate_size: int,
        local_expert_offset: int,
        local_num_experts: int,
        use_routing_scales_on_input: bool,
        routing_method_type: int,
        routed_scaling_factor: float = 1.0) -> torch.Tensor:
    """Run the FlashInfer TRT-LLM per-tensor-scale FP8 MoE.

    `hidden_states` is quantized through `moe_kernel_quantize_input` with
    `quant_dtype=torch.float8_e4m3fn` and `per_act_token_quant=False`, using
    `input_scale`; the scale returned by that helper is discarded. The
    quantized activations and the remaining arguments are then forwarded to
    `flashinfer_trtllm_fp8_per_tensor_scale_moe`.

    Args:
        routing_logits: Forwarded unchanged.
        routing_bias: Forwarded unchanged (may be None).
        hidden_states: Activations to quantize; `hidden_states.shape[0]` is
            also used to pick the tile size.
        input_scale: Scale handed to `moe_kernel_quantize_input`.
        gemm1_weights: Forwarded unchanged.
        gemm2_weights: Forwarded unchanged.
        output1_scales_scalar: Forwarded unchanged.
        output1_scales_gate_scalar: Forwarded unchanged.
        output2_scales_scalar: Forwarded unchanged.
        num_experts: Forwarded unchanged.
        top_k: Forwarded unchanged.
        num_expert_group: Forwarded as `n_group`; None is replaced by 0.
        topk_group: Forwarded as `topk_group`; None is replaced by 0.
        intermediate_size: Forwarded unchanged.
        local_expert_offset: Forwarded unchanged.
        local_num_experts: Forwarded unchanged.
        use_routing_scales_on_input: Forwarded unchanged.
        routing_method_type: Forwarded unchanged.
        routed_scaling_factor: Forwarded unchanged.

    Returns:
        Whatever `flashinfer_trtllm_fp8_per_tensor_scale_moe` returns.
    """
    # The kernel is given 0 when no grouping information was supplied.
    if num_expert_group is None:
        num_expert_group = 0
    if topk_group is None:
        topk_group = 0

    fp8_hidden_states, _ = moe_kernel_quantize_input(
        hidden_states,
        input_scale,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False)

    from vllm.utils.flashinfer import (
        flashinfer_trtllm_fp8_per_tensor_scale_moe)

    num_tokens = hidden_states.shape[0]
    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, top_k,
                                                num_experts)

    return flashinfer_trtllm_fp8_per_tensor_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=fp8_hidden_states,
        gemm1_weights=gemm1_weights,
        output1_scales_scalar=output1_scales_scalar,
        output1_scales_gate_scalar=output1_scales_gate_scalar,
        gemm2_weights=gemm2_weights,
        output2_scales_scalar=output2_scales_scalar,
        num_experts=num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=local_expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling_factor,
        use_routing_scales_on_input=use_routing_scales_on_input,
        tile_tokens_dim=tile_tokens_dim,
        routing_method_type=routing_method_type)
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
        routing_logits: torch.Tensor,
        routing_bias: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        input_scale: torch.Tensor,
        gemm1_weights: torch.Tensor,
        gemm2_weights: torch.Tensor,
        output1_scales_scalar: torch.Tensor,
        output1_scales_gate_scalar: torch.Tensor,
        output2_scales_scalar: torch.Tensor,
        num_experts: int,
        top_k: int,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
        intermediate_size: int,
        local_expert_offset: int,
        local_num_experts: int,
        use_routing_scales_on_input: bool,
        routing_method_type: int,
        routed_scaling_factor: float = 1.0) -> torch.Tensor:
    """Fake counterpart of `flashinfer_fused_moe_per_tensor_scale_fp8`.

    Performs no computation: every argument except `hidden_states` is
    ignored, and the result is a freshly allocated, uninitialized tensor
    created with `torch.empty_like(hidden_states)`.

    Returns:
        An uninitialized tensor with the same shape, dtype and device as
        `hidden_states`.
    """
    placeholder_output = torch.empty_like(hidden_states)
    return placeholder_output
# TODO(bnell): Does this really need to be a torch.op?
# Register flashinfer_fused_moe_per_tensor_scale_fp8 under the op name
# "flashinfer_fused_moe_per_tensor_scale_fp8", paired with its *_fake
# counterpart (which only allocates an output shaped like `hidden_states`).
# The op is tagged with torch.Tag.needs_fixed_stride_order.
# NOTE(review): `hidden_states` is declared as mutated, but the
# implementation in this file only passes it to moe_kernel_quantize_input and
# uses the returned tensor -- whether an in-place write actually happens
# depends on that helper; confirm the declaration is intended.
direct_register_custom_op(
op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
mutates_args=["hidden_states"],
fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config) try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched) TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
...@@ -498,8 +498,6 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -498,8 +498,6 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int, num_experts: int,
...@@ -545,14 +543,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -545,14 +543,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dtype=torch.float32, dtype=torch.float32,
device=a1.device) device=a1.device)
else: else:
assert a1_scale is None assert quant_config.a1_scale is None
b_a1_scale = None b_a1_scale = None
first_expert = num_local_experts * self.rank first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts last_expert = first_expert + num_local_experts
a1_scale = normalize_scales_shape(a1_scale) a1_scale = normalize_scales_shape(quant_config.a1_scale)
a2_scale = normalize_scales_shape(a2_scale)
for expert_id in range(first_expert, last_expert): for expert_id in range(first_expert, last_expert):
topks = torch.any(topk_ids == expert_id, dim=1).flatten() topks = torch.any(topk_ids == expert_id, dim=1).flatten()
...@@ -623,28 +620,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -623,28 +620,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
self, self,
max_num_tokens: int, max_num_tokens: int,
num_dispatchers: int, num_dispatchers: int,
use_fp8_w8a8: bool = False, quant_config: FusedMoEQuantConfig,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
): ):
super().__init__( super().__init__(quant_config)
FusedMoEQuantConfig.make( assert not self.quant_config.use_int8_w8a8, "NYI"
use_fp8_w8a8=use_fp8_w8a8, assert not self.quant_config.use_int8_w8a16, "NYI"
use_int8_w8a8=use_int8_w8a8, assert not self.quant_config.use_int4_w4a16, "NYI"
use_int8_w8a16=use_int8_w8a16, assert not self.quant_config.use_mxfp4_w4a4, "NYI"
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
assert not use_mxfp4_w4a4, "NYI"
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers self.num_dispatchers = num_dispatchers
...@@ -705,12 +687,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -705,12 +687,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
...@@ -740,10 +717,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -740,10 +717,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
tmp = _resize_cache(workspace2, (num, N)) tmp = _resize_cache(workspace2, (num, N))
if self.quant_config.is_quantized: if self.quant_config.is_quantized:
assert a1q_scale is not None and w1_scale is not None assert a1q_scale is not None and self.w1_scale is not None
input = self.dequant(hidden_states[expert, :, :], input = self.dequant(hidden_states[expert, :, :],
a1q_scale[expert]) a1q_scale[expert])
w1_dq = self.dequant(w1[expert], w1_scale[expert]) w1_dq = self.dequant(w1[expert], self.w1_scale[expert])
input = input[:num] @ w1_dq.transpose(0, 1) input = input[:num] @ w1_dq.transpose(0, 1)
else: else:
input = hidden_states[expert, :num, :] @ w1[expert].transpose( input = hidden_states[expert, :num, :] @ w1[expert].transpose(
...@@ -752,8 +729,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -752,8 +729,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.activation(activation, tmp, input.to(tmp.dtype)) self.activation(activation, tmp, input.to(tmp.dtype))
if self.quant_config.is_quantized: if self.quant_config.is_quantized:
assert w2_scale is not None assert self.w2_scale is not None
w2_dq = self.dequant(w2[expert], w2_scale[expert]) w2_dq = self.dequant(w2[expert], self.w2_scale[expert])
else: else:
w2_dq = w2[expert] w2_dq = w2[expert]
...@@ -840,35 +817,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -840,35 +817,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self, self,
max_num_tokens: int, max_num_tokens: int,
num_dispatchers: int, num_dispatchers: int,
use_fp8_w8a8: bool = False, quant_config: FusedMoEQuantConfig,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
): ):
super().__init__( super().__init__(quant_config)
FusedMoEQuantConfig.make( assert not self.quant_config.use_int8_w8a8, "NYI"
use_fp8_w8a8=use_fp8_w8a8, assert not self.quant_config.use_int8_w8a16, "NYI"
use_int8_w8a8=use_int8_w8a8, assert not self.quant_config.use_int4_w4a16, "NYI"
use_int8_w8a16=use_int8_w8a16, assert not self.quant_config.use_mxfp4_w4a4, "NYI"
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
assert not use_mxfp4_w4a4, "NYI"
assert max_num_tokens > 0 assert max_num_tokens > 0
assert num_dispatchers > 0 assert num_dispatchers > 0
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int8_w8a8 = use_int8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a16 = use_int8_w8a16
self.use_mxfp4_w4a4 = use_mxfp4_w4a4
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers self.num_dispatchers = num_dispatchers
...@@ -921,19 +878,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -921,19 +878,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
): ):
# Check constraints. # Check constraints.
if self.use_int4_w4a16: if self.quant_config.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), ( assert hidden_states.size(-1) // 2 == w1.size(2), (
"Hidden size mismatch") "Hidden size mismatch")
else: else:
...@@ -958,11 +910,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -958,11 +910,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert w1.size(0) == E assert w1.size(0) == E
assert w2.size(0) == E assert w2.size(0) == E
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, config_dtype = self.quant_config.config_name(hidden_states.dtype)
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
dtype=hidden_states.dtype)
config = try_get_optimal_moe_config( config = try_get_optimal_moe_config(
w1.size(), w1.size(),
...@@ -992,7 +940,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -992,7 +940,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache2 = _resize_cache(workspace2, intermediate_cache2 = _resize_cache(workspace2,
(E, max_num_tokens, N // 2)) (E, max_num_tokens, N // 2))
if self.use_fp8_w8a8: # TODO(bnell): should this be done for any quantized type?
if self.quant_config.use_fp8_w8a8:
intermediate_cache1.fill_(0) intermediate_cache1.fill_(0)
a1q_scale = normalize_batched_scales_shape(a1q_scale, E) a1q_scale = normalize_batched_scales_shape(a1q_scale, E)
...@@ -1005,11 +954,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1005,11 +954,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens=expert_num_tokens, expert_num_tokens=expert_num_tokens,
compute_type=compute_type, compute_type=compute_type,
A_scale=a1q_scale, A_scale=a1q_scale,
B_scale=w1_scale, B_scale=self.w1_scale,
B_zp=w1_zp, B_zp=self.w1_zp,
use_fp8_w8a8=self.use_fp8_w8a8, use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16, use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16, use_int4_w4a16=self.quant_config.use_int4_w4a16,
config=config, config=config,
per_act_token_quant=self.per_act_token_quant, per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape) block_shape=self.block_shape)
...@@ -1021,7 +970,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1021,7 +970,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache1.view(-1, N)) intermediate_cache1.view(-1, N))
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
intermediate_cache2, a2_scale, max_num_tokens, E, N, intermediate_cache2, self.a2_scale, max_num_tokens, E, N,
expert_num_tokens, self.quant_dtype, self.per_act_token_quant, expert_num_tokens, self.quant_dtype, self.per_act_token_quant,
self.block_shape) self.block_shape)
...@@ -1032,11 +981,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1032,11 +981,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens=expert_num_tokens, expert_num_tokens=expert_num_tokens,
compute_type=compute_type, compute_type=compute_type,
A_scale=a2q_scale, A_scale=a2q_scale,
B_scale=w2_scale, B_scale=self.w2_scale,
B_zp=w2_zp, B_zp=self.w2_zp,
use_fp8_w8a8=self.use_fp8_w8a8, use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16, use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16, use_int4_w4a16=self.quant_config.use_int4_w4a16,
config=config, config=config,
per_act_token_quant=self.per_act_token_quant, per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape) block_shape=self.block_shape)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional from typing import Optional
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate) TopKWeightAndReduceDelegate)
from vllm.utils import has_triton_kernels from vllm.utils import has_triton_kernels
...@@ -23,9 +25,6 @@ if has_triton_kernels(): ...@@ -23,9 +25,6 @@ if has_triton_kernels():
"Failed to import Triton kernels. Please make sure your triton " "Failed to import Triton kernels. Please make sure your triton "
"version is compatible.") "version is compatible.")
if TYPE_CHECKING:
from triton_kernels.matmul_ogs import PrecisionConfig
def triton_kernel_moe_forward( def triton_kernel_moe_forward(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -35,20 +34,10 @@ def triton_kernel_moe_forward( ...@@ -35,20 +34,10 @@ def triton_kernel_moe_forward(
topk: int, topk: int,
renormalize: bool, renormalize: bool,
activation: str = "silu", activation: str = "silu",
quant_config: Optional[FusedMoEQuantConfig] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_precision: Optional["PrecisionConfig"] = None,
w2_precision: Optional["PrecisionConfig"] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
routing_data, gather_idx, scatter_idx = routing(gating_output, routing_data, gather_idx, scatter_idx = routing(gating_output,
...@@ -64,20 +53,10 @@ def triton_kernel_moe_forward( ...@@ -64,20 +53,10 @@ def triton_kernel_moe_forward(
gather_idx, gather_idx,
scatter_idx, scatter_idx,
activation=activation, activation=activation,
quant_config=quant_config,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map)
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_precision=w1_precision,
w2_precision=w2_precision,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
# This is a triton implementation of the fused_experts function # This is a triton implementation of the fused_experts function
...@@ -90,28 +69,23 @@ def triton_kernel_fused_experts( ...@@ -90,28 +69,23 @@ def triton_kernel_fused_experts(
gather_indx, # GatherIndx gather_indx, # GatherIndx
scatter_indx, # ScatterIndx scatter_indx, # ScatterIndx
activation: str = "silu", activation: str = "silu",
quant_config: Optional[FusedMoEQuantConfig] = None,
swiglu_alpha: float = 1.702, swiglu_alpha: float = 1.702,
swiglu_limit: float = 7.0, swiglu_limit: float = 7.0,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None, a1q_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_precision: Optional["PrecisionConfig"] = None,
w2_precision: Optional["PrecisionConfig"] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
# type check, uint8 means mxfp4 # type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16 assert hidden_states.dtype == torch.bfloat16
assert w1_bias is None or w1_bias.dtype == torch.float32 assert (quant_config.w1_bias is None
assert w2_bias is None or w2_bias.dtype == torch.float32 or quant_config.w1_bias.dtype == torch.float32)
assert (quant_config.w2_bias is None
or quant_config.w2_bias.dtype == torch.float32)
# Shape check, only check non-mxfp4 # Shape check, only check non-mxfp4
assert hidden_states.shape[-1] == w1.shape[-2] assert hidden_states.shape[-1] == w1.shape[-2]
...@@ -130,20 +104,20 @@ def triton_kernel_fused_experts( ...@@ -130,20 +104,20 @@ def triton_kernel_fused_experts(
intermediate_cache1 = matmul_ogs( intermediate_cache1 = matmul_ogs(
hidden_states, hidden_states,
w1, w1,
w1_bias, quant_config.w1_bias,
routing_data, routing_data,
gather_indx=gather_indx, gather_indx=gather_indx,
precision_config=w1_precision, precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None, gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act) fused_activation=act)
intermediate_cache3 = matmul_ogs( intermediate_cache3 = matmul_ogs(
intermediate_cache1, intermediate_cache1,
w2, w2,
w2_bias, quant_config.w2_bias,
routing_data, routing_data,
scatter_indx=scatter_indx, scatter_indx=scatter_indx,
precision_config=w2_precision, precision_config=quant_config.w2_precision,
gammas=None if apply_router_weight_on_input else gammas, gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor, y=output_tensor,
) )
...@@ -154,21 +128,13 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -154,21 +128,13 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
quant_config,
max_num_tokens: int, max_num_tokens: int,
num_dispatchers: int, num_dispatchers: int,
w1_precision: "PrecisionConfig", quant_config: FusedMoEQuantConfig,
w2_precision: "PrecisionConfig",
w1_bias: Optional[torch.Tensor],
w2_bias: Optional[torch.Tensor],
): ):
super().__init__(quant_config) super().__init__(quant_config)
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers self.num_dispatchers = num_dispatchers
self.w1_precision = w1_precision
self.w2_precision = w2_precision
self.w1_bias = w1_bias
self.w2_bias = w2_bias
@property @property
def activation_formats( def activation_formats(
...@@ -212,12 +178,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -212,12 +178,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
...@@ -228,20 +189,12 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -228,20 +189,12 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states, hidden_states,
w1, w1,
w2, w2,
None, routing_data=None,
None, gather_indx=None,
None, scatter_indx=None,
activation=activation, activation=activation,
quant_config=self.quant_config,
apply_router_weight_on_input=False, apply_router_weight_on_input=False,
use_fp8_w8a8=False,
per_channel_quant=False,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=w1_scale, a1q_scale=a1q_scale)
w2_scale=w2_scale,
w1_bias=self.w1_bias,
w2_bias=self.w2_bias,
w1_precision=self.w1_precision,
w2_precision=self.w2_precision,
a1_scale=a1q_scale,
a2_scale=a2_scale)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment