"tests/vscode:/vscode.git/clone" did not exist on "f708bd4904ee15bdf9e86503439f2408aa754cda"
Unverified commit 8ad7285e, authored by bnellnm, committed by GitHub
Browse files

[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra...


[Kernels] Clean up FusedMoeMethodBase and modular kernel setup.  Remove extra arguments from modular kernel methods. (#22035)
Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
parent 48b01fd4
......@@ -49,7 +49,8 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8)
CutlassBatchedExpertsFp8, CutlassExpertsFp8, cutlass_moe_fp4,
cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
......@@ -69,6 +70,7 @@ if HAS_TRITON:
"cutlass_moe_fp8",
"cutlass_moe_fp4",
"CutlassExpertsFp8",
"CutlassBatchedExpertsFp8",
"TritonExperts",
"BatchedTritonExperts",
"DeepGemmExperts",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional
import torch
......@@ -254,18 +254,28 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
):
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional
import torch
......@@ -132,18 +132,28 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a, aq, M, N, K, topk, global_num_experts, local_num_experts,
expert_tokens_metadata)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
):
experts = (self.batched_deep_gemm_experts
if self.allow_deep_gemm else self.batched_triton_experts)
assert experts is not None
......@@ -151,4 +161,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_tokens_meta,
apply_router_weight_on_input, extra_expert_args)
apply_router_weight_on_input)
......@@ -45,7 +45,6 @@ def get_quant_config_weight_quant(
return _get_quant_config_quantization_args(quant_config, "weights")
# TODO (bnell): use scalar_type instead of bools?
def get_config_quant_dtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
......@@ -65,7 +64,8 @@ def get_config_quant_dtype(
@dataclass
class FusedMoEQuantConfig:
# The post quantization activation type.
quant_dtype: Optional[torch.dtype] = None
# TODO (bnell): use scalar_type instead of Union.
quant_dtype: Union[torch.dtype, str, None] = None
per_act_token_quant: bool = False
per_out_ch_quant: bool = False
block_shape: Optional[list[int]] = None
......@@ -141,6 +141,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
use_mxfp4_w4a4,
]
]) <= 1, "Quantization flags are mutually exclusive."
......@@ -334,7 +335,7 @@ class FusedMoEConfig:
assert self.max_num_tokens > 0
@property
def quant_dtype(self) -> Optional[torch.dtype]:
def quant_dtype(self) -> Union[torch.dtype, str, None]:
if self.quant_config is not None:
return self.quant_config.quant_dtype
else:
......@@ -429,7 +430,7 @@ class FusedMoEConfig:
block_shape = None
per_act_token_quant = False
per_out_ch_quant = False
quant_dtype: Optional[torch.dtype] = None
quant_dtype: Union[torch.dtype, str, None] = None
input_quant = get_quant_config_input_quant(quant_config)
weight_quant = get_quant_config_weight_quant(quant_config)
......@@ -453,7 +454,7 @@ class FusedMoEConfig:
ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config,
ModelOptNvFp4Config):
quant_dtype = torch.uint8
quant_dtype = "nvfp4"
if weight_quant is not None:
per_out_ch_quant = (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" CUTLASS based Fused MoE kernels."""
from typing import Any, Callable, Optional
from typing import Callable, Optional
import torch
......@@ -12,11 +12,10 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
_resize_cache,
extract_required_args)
_resize_cache)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
......@@ -213,19 +212,14 @@ def run_cutlass_moe_fp8(
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
# TODO (bnell): split class batched vs. non-batched?
# maybe remove need for passing aq to workspace_shapes
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_experts_per_worker: int,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
):
super().__init__(
FusedMoEQuantConfig(
......@@ -234,34 +228,139 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
))
self.out_dtype = out_dtype
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def apply(
    self,
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    w1_scale: Optional[torch.Tensor],
    w2_scale: Optional[torch.Tensor],
    w1_zp: Optional[torch.Tensor],
    w2_zp: Optional[torch.Tensor],
    a1q_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    apply_router_weight_on_input: bool,
):
    """Forward the expert computation to ``run_cutlass_moe_fp8``.

    Zero points are rejected up front (both ``w1_zp`` and ``w2_zp`` must be
    ``None``). ``topk_weights`` and ``apply_router_weight_on_input`` are
    part of the signature but are not read by this method.
    """
    assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
    assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"

    # Per-expert token counts are only available when metadata is supplied.
    expert_num_tokens = (expert_tokens_meta.expert_num_tokens
                         if expert_tokens_meta is not None else None)

    def activation_callable(o, i):
        # Bind the activation name so the kernel only passes (out, in).
        return self.activation(activation, o, i)

    # Whether the batched path is used follows from the input activation
    # format that the concrete subclass declares.
    input_format = self.activation_formats[0]
    use_batched_format = (
        input_format == mk.FusedMoEActivationFormat.BatchedExperts)

    # Fall back to the input dtype when no explicit output dtype was set.
    result_dtype = self.out_dtype
    if result_dtype is None:
        result_dtype = hidden_states.dtype

    run_cutlass_moe_fp8(
        output, hidden_states, w1, w2, topk_ids, activation_callable,
        global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
        a2_scale, workspace13, workspace2, expert_num_tokens, result_dtype,
        self.per_act_token_quant, self.per_out_ch_quant, use_batched_format)
class CutlassExpertsFp8(CutlassExpertsFp8Base):
    """CUTLASS FP8 experts that consume and produce the Standard format."""

    def __init__(
        self,
        out_dtype: Optional[torch.dtype],
        per_act_token_quant: bool,
        per_out_ch_quant: bool,
        block_shape: Optional[list[int]] = None,
    ):
        """Pass all configuration straight through to the base class."""
        super().__init__(out_dtype, per_act_token_quant, per_out_ch_quant,
                         block_shape)

    @property
    def activation_formats(
        self
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        """Return the (input, output) activation formats: both Standard."""
        standard = mk.FusedMoEActivationFormat.Standard
        return (standard, standard)

    def supports_chunking(self) -> bool:
        """Report that chunking is supported."""
        return True

    def supports_expert_map(self) -> bool:
        """Report that an expert map is supported."""
        return True

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        """Return (workspace1 shape, workspace2 shape, output shape, dtype).

        Every shape has ``M * topk`` rows. The dtype is ``self.out_dtype``
        when set, otherwise the dtype of ``a``.
        """
        num_rows = M * topk
        result_dtype = a.dtype if self.out_dtype is None else self.out_dtype
        return (
            (num_rows, max(N, K)),
            (num_rows, N // 2),
            (num_rows, K),
            result_dtype,
        )
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
max_experts_per_worker: int,
num_dispatchers: int,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
block_shape,
)
assert max_experts_per_worker > 0
assert not use_batched_format or num_dispatchers is not None
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
if self.use_batched_format:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
else:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_chunking(self) -> bool:
return not self.use_batched_format
return False
def supports_expert_map(self) -> bool:
return not self.use_batched_format
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
return False
# TODO(bnell): maybe remove need for passing aq to workspace_shapes
def workspace_shapes(
self,
a: torch.Tensor,
......@@ -274,55 +373,16 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = ()
if self.use_batched_format:
padded_M = aq.size(1)
num_dp = self.num_dispatchers
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
(N // 2))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2))
output = (self.max_experts_per_worker, padded_M, K)
else:
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, N // 2)
output = (M * topk, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None
if expert_tokens_meta is not None:
expert_num_tokens = expert_tokens_meta.expert_num_tokens
activation_callable = lambda o, i: self.activation(activation, o, i)
in_dtype = hidden_states.dtype
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
self.use_batched_format)
def cutlass_moe_fp8(
a: torch.Tensor,
......@@ -387,11 +447,9 @@ def cutlass_moe_fp8(
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
max_experts_per_worker=num_experts,
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
use_batched_format=False,
),
)
......@@ -476,8 +534,9 @@ def run_cutlass_moe_fp4(
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape
assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
" between weights.")
assert (e_w1 == e_w2
and e_w1 == e), ("Number of experts must match",
f" between weights. {e_w1}, {e_w2}, {e}")
assert (k_a == half_k_w1 * 2
and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in "
......@@ -554,6 +613,10 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
max_experts_per_worker: int,
out_dtype: torch.dtype,
per_act_token_quant: bool,
......@@ -562,8 +625,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
use_batched_format: bool = False,
):
super().__init__(
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
FusedMoEQuantConfig(
quant_dtype=torch.uint8,
quant_dtype=None, # skip quantization in prepare/finalize
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
......@@ -572,6 +639,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
# TODO(bnell): put this stuff into quant config?
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
@property
def activation_formats(
self
......@@ -590,8 +663,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
......@@ -620,34 +692,42 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor],
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: torch.Tensor,
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
required_keys = [
"g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k",
"e", "device"
]
(g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e,
device) = extract_required_args(extra_expert_args, required_keys)
):
e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids)
n = w2.shape[2] * 2
run_cutlass_moe_fp4(
output=output,
a=hidden_states,
a1_gscale=a1_gscale,
a1_gscale=self.a1_gscale,
w1_fp4=w1,
w1_blockscale=w1_scale,
w1_alphas=g1_alphas,
a2_gscale=a2_gscale,
w1_alphas=self.g1_alphas,
a2_gscale=self.a2_gscale,
w2_fp4=w2,
w2_blockscale=w2_scale,
w2_alphas=g2_alphas,
w2_alphas=self.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
workspace13=workspace13,
......@@ -656,7 +736,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
n=n,
k=k,
e=e,
device=device,
device=hidden_states.device,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -677,7 +757,6 @@ def cutlass_moe_fp4(
n: int,
k: int,
e: int,
device: torch.device,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False) -> torch.Tensor:
assert expert_map is None, ("Expert Parallelism / expert_map "
......@@ -686,6 +765,10 @@ def cutlass_moe_fp4(
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
g1_alphas,
g2_alphas,
a1_gscale,
a2_gscale,
max_experts_per_worker=e,
out_dtype=a.dtype,
per_act_token_quant=False,
......@@ -693,29 +776,7 @@ def cutlass_moe_fp4(
use_batched_format=False,
),
)
extra_expert_args = {
'g1_alphas': g1_alphas,
'g2_alphas': g2_alphas,
'a1_gscale': a1_gscale,
'a2_gscale': a2_gscale,
'm': m,
'n': n,
'k': k,
'e': e,
'device': device,
}
# NVFP4 requires two levels of quantization, which involves computing some
# scaling factors dynamically. This makes it incompatible with the typical
# prepare -> MoE -> finalize pipeline. Move the quantization logic into the
# MoE body.
extra_prepare_args = {
'skip_quant': True,
}
# Similar reason as above.
extra_finalize_args = {
'skip_weight_reduce': True,
}
return fn(
hidden_states=a,
w1=w1_fp4,
......@@ -731,9 +792,6 @@ def cutlass_moe_fp4(
a1_scale=None,
a2_scale=None,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args,
extra_prepare_args=extra_prepare_args,
extra_finalize_args=extra_finalize_args,
)
......@@ -824,16 +882,6 @@ def run_cutlass_block_scaled_fused_experts(
k = w1_q.size(1)
n = w2_q.size(1)
expert_offsets = torch.empty((num_experts + 1, ),
dtype=torch.int32,
device="cuda")
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device="cuda")
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device="cuda")
topk = topk_ids.size(1)
a_q, a1_scale = _fp8_quantize(a,
......@@ -842,6 +890,16 @@ def run_cutlass_block_scaled_fused_experts(
block_shape=[128, 128])
device = a_q.device
expert_offsets = torch.empty((num_experts + 1, ),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import Any, Optional
from typing import Optional
import torch
from tqdm import tqdm
......@@ -230,7 +230,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
assert self.block_shape is not None
assert a1q_scale is not None
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional
import deep_ep
import torch
......@@ -127,12 +127,16 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights)
def prepare(
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
......@@ -187,11 +191,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
) -> None:
assert self.handle is not None
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional, Union
from typing import Optional, Union
import deep_ep
import torch
......@@ -77,7 +77,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype],
quant_dtype: Union[torch.dtype, str, None],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
......@@ -111,12 +111,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return x, x_scales
def prepare(
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
......@@ -162,11 +166,15 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return (expert_x, expert_x_scale, expert_tokens_meta, None, None)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
) -> None:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional, Union
import torch
......@@ -8,8 +8,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import extract_required_args
TopKWeightAndReduceNoOP)
from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
has_flashinfer_cutlass_fused_moe)
......@@ -43,31 +42,34 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_nvfp4_w4a4: bool = False,
use_fp8_w8a8: bool = False,
use_dp: bool = False,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
out_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None],
ep_rank: int = 0,
ep_size: int = 1,
tp_rank: int = 0,
tp_size: int = 1,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.uint8,
quant_dtype=quant_dtype,
per_act_token_quant=False,
block_shape=None,
))
self.use_nvfp4_w4a4 = use_nvfp4_w4a4
self.use_fp8_w8a8 = use_fp8_w8a8
assert quant_dtype == "nvfp4", ("Only nvfp4 quantization is "
"currently supported.")
self.ep_rank = ep_rank
self.ep_size = ep_size
self.tp_rank = tp_rank
self.tp_size = tp_size
self.use_dp = use_dp
assert not use_batched_format or num_dispatchers is not None
self.num_dispatchers = num_dispatchers
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
self.out_dtype = out_dtype
@property
def activation_formats(
......@@ -84,8 +86,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
......@@ -117,8 +118,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
"currently supported.")
aq_m, aq_n = aq.shape
workspace2 = ()
output_shape = (aq_m, aq_n * 2)
......@@ -149,21 +148,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: Optional[bool],
extra_expert_args: Optional[dict[str, Any]],
):
assert extra_expert_args is not None, \
"extra_expert_args must be provided"
required_keys = [
'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale', 'out_dtype'
]
g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = (
extract_required_args(extra_expert_args, required_keys))
# Flashinfer CUTLASS kernel takes scalar global scales,
# min because inv_scale.
assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
"currently supported.")
# Ensure w1_scale and w2_scale are not None before calling view
assert w1_scale is not None and w2_scale is not None, (
......@@ -171,12 +158,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
"be None for FlashInferExperts")
quant_scales = [
a1_gscale,
self.a1_gscale,
w1_scale.view(torch.int32),
g1_alphas,
a2_gscale,
self.g1_alphas,
self.a2_gscale,
w2_scale.view(torch.int32),
g2_alphas,
self.g2_alphas,
]
_ = flashinfer_cutlass_fused_moe(
input=hidden_states,
......@@ -185,7 +172,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FlashInfer API requires weight to be long for nvfp4
fc1_expert_weights=w1.view(torch.long),
fc2_expert_weights=w2.view(torch.long),
output_dtype=out_dtype,
output_dtype=self.out_dtype,
quant_scales=quant_scales,
input_sf=a1q_scale,
tp_size=self.tp_size,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional
import torch
......@@ -9,7 +9,7 @@ from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
extract_required_args, moe_kernel_quantize_input)
moe_kernel_quantize_input)
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
......@@ -21,16 +21,15 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def __init__(
self,
quant_dtype: Optional[torch.dtype] = None,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
use_dp: bool,
a1_gscale: Optional[torch.Tensor],
num_dispatchers: int = 1,
):
super().__init__()
self.per_channel_quant = per_channel_quant
self.block_shape = block_shape
self.quant_dtype = quant_dtype
self.num_dispatchers_ = num_dispatchers
self.use_dp = use_dp
self.a1_gscale = a1_gscale
self.local_tokens = None
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
......@@ -55,10 +54,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
# TODO(bnell): use quant_config + scales instead of ctor args
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
......@@ -67,22 +67,22 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
(a1_gscale, use_dp, local_tokens) = extract_required_args(
extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
a1_gscale,
self.a1_gscale,
quant_config.quant_dtype,
self.per_channel_quant,
self.block_shape,
is_fp4_scale_swizzled=not use_dp, # Swizzling after communication
quant_config.per_act_token_quant,
quant_config.block_shape,
# Swizzling after communication
is_fp4_scale_swizzled=not self.use_dp,
)
if use_dp:
if self.use_dp:
topk_weights, topk_ids, a1q, a1q_scale = \
get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], # noqa: E501
get_dp_group().all_gatherv(
[topk_weights, topk_ids, a1q, a1q_scale],
dim=0,
sizes=get_local_sizes())
sizes=get_local_sizes(),
)
a1_m, a1_n = a1q.shape
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
......@@ -91,13 +91,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
(use_dp,
local_tokens) = extract_required_args(extra_finalize_args,
['use_dp', 'local_tokens'])
if use_dp:
if self.use_dp:
fused_expert_output = get_dp_group().reduce_scatterv(
fused_expert_output, dim=0, sizes=get_local_sizes())
output.copy_(fused_expert_output)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused batched MoE kernel."""
from typing import Any, Optional
from typing import Optional
import torch
......@@ -496,12 +496,16 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return self.num_dispatchers_
def prepare(
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
......@@ -590,11 +594,15 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return b_a1, b_a1_scale, expert_tokens_meta, None, None
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank)
weight_and_reduce_impl.apply(
......@@ -688,18 +696,28 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
else:
return t.to(f32) * group_broadcast(scale, t.shape)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
):
assert hidden_states.dim() == 3
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
......@@ -894,18 +912,28 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output, a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
):
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
......
......@@ -1394,9 +1394,9 @@ def fused_experts(hidden_states: torch.Tensor,
# E8M0 scale, which means we requantize the weight and input to the specific
# scale. Falling back to cutlass or triton in some cases would cause
# accuracy issues.
should_use_deep_gemm = is_blackwell_deep_gemm_e8m0_used(
) or _valid_deep_gemm(hidden_states, w1, w2)
if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
if (allow_deep_gemm and use_fp8_w8a8
and (is_blackwell_deep_gemm_e8m0_used()
or _valid_deep_gemm(hidden_states, w1, w2))):
assert apply_router_weight_on_input is False
assert is_act_and_mul, (
"DeepGemm only supports is_act_and_mul=True for now.")
......@@ -1905,7 +1905,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
# Check constraints.
if self.use_int4_w4a16:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional
import torch
......@@ -8,7 +8,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import extract_required_args
from vllm.utils import has_triton_kernels
logger = init_logger(__name__)
......@@ -160,12 +159,16 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
num_dispatchers: int,
w1_precision: "PrecisionConfig",
w2_precision: "PrecisionConfig",
w1_bias: Optional[torch.Tensor],
w2_bias: Optional[torch.Tensor],
):
super().__init__(quant_config)
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
self.w1_precision = w1_precision
self.w2_precision = w2_precision
self.w1_bias = w1_bias
self.w2_bias = w2_bias
@property
def activation_formats(
......@@ -219,11 +222,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
w1_bias, w2_bias = (extract_required_args(extra_expert_args,
["w1_bias", "w2_bias"]))
return triton_kernel_fused_experts(
output,
hidden_states,
......@@ -240,8 +239,8 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_bias=self.w1_bias,
w2_bias=self.w2_bias,
w1_precision=self.w1_precision,
w2_precision=self.w2_precision,
a1_scale=a1q_scale,
......
......@@ -37,7 +37,6 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
round_up)
from vllm.utils.flashinfer import has_flashinfer
if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts
......@@ -49,9 +48,6 @@ if current_platform.is_cuda_alike():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
DeepEPLLPrepareAndFinalize)
if has_flashinfer():
from .flashinfer_cutlass_prepare_finalize import (
FlashInferCutlassMoEPrepareAndFinalize)
else:
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore
......@@ -80,7 +76,12 @@ class FusedMoeWeightScaleSupported(Enum):
class FusedMoEMethodBase(QuantizeMethodBase):
moe: FusedMoEConfig
# TODO(bnell): also pass quant_config?
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.moe = moe
self.fused_experts: Optional[Callable] = None
self.topk_indices_dtype = None
@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
......@@ -99,16 +100,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
return False
@staticmethod
def maybe_make_prepare_finalize(
moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]:
def _maybe_make_prepare_finalize(
moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
if moe.use_flashinfer_cutlass_kernels:
prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize(
quant_dtype=moe.quant_dtype, )
assert not moe.use_flashinfer_cutlass_kernels, \
"Must be created in modelopt.py"
if moe.use_pplx_kernels:
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens,
......@@ -188,14 +189,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
return prepare_finalize
def init_prepare_finalize(self, moe: FusedMoEConfig):
self.moe = moe
prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(
self.moe)
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[FusedMoEPrepareAndFinalize]:
if moe.moe_parallel_config.use_all2all_kernels:
return FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
else:
return None
def init_prepare_finalize(self):
assert self.moe is not None
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
self.topk_indices_dtype = None
if prepare_finalize is not None:
logger.debug("%s", prepare_finalize.__class__.__name__)
logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__,
self, id(self))
assert self.topk_indices_dtype is None
assert self.fused_experts is None, \
f"Attempt to override experts for {id(self)}!"
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize, self.moe)
self.fused_experts = FusedMoEModularKernel(
......@@ -214,12 +226,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
f"{self.__class__.__name__} must select appropriate gemm "
"implementation based on the prepare_finalize")
def maybe_swap_experts_impl(
self,
moe_parallel_config: FusedMoEParallelConfig,
):
pass
@abstractmethod
def apply(
self,
......@@ -251,10 +257,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.fused_experts = fused_experts # type: ignore
self.topk_indices_dtype = None
self.moe = moe
super().__init__(moe)
self.has_bias = self.moe.has_bias
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
if self.rocm_aiter_moe_enabled:
......@@ -266,6 +269,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
# TODO(bnell): Remove. Every layer should have an moe config object.
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
if (prepare_finalize.activation_format ==
......@@ -474,12 +478,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
elif self.fused_experts is not None:
if self.has_bias:
raise ValueError(
"FusedMoEModularKernel does not support bias.")
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
else:
# add w1_bias/w2_bias to kwargs if they exist
kwargs = dict(
assert fused_experts is not None
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_bias=layer.w13_bias if self.has_bias else None,
w2_bias=layer.w2_bias if self.has_bias else None,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
......@@ -488,17 +510,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=global_num_experts,
expert_map=expert_map,
)
if isinstance(self.fused_experts,
FusedMoEModularKernel) and self.has_bias:
raise ValueError(
"FusedMoEModularKernel does not support bias.")
if self.has_bias:
kwargs.update({
"w1_bias": getattr(layer, "w13_bias", None),
"w2_bias": getattr(layer, "w2_bias", None),
})
return self.fused_experts(**kwargs)
def forward_cpu(
self,
......@@ -868,8 +879,6 @@ class FusedMoE(CustomOp):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
if isinstance(self.quant_method, FusedMoEMethodBase):
self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config)
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
......
......@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from math import prod
from typing import Any, Optional, final
from typing import Optional, final
import torch
......@@ -150,15 +150,23 @@ class FusedMoEPrepareAndFinalize(ABC):
@abstractmethod
def prepare(
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[ExpertTokensMetadata],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
"""
Perform any quantization (and/or) dispatching needed
for this kernel.
......@@ -186,11 +194,15 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError
@abstractmethod
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
) -> None:
"""
Perform any combine plus apply weights and perform a reduction on the
fused experts output.
......@@ -368,7 +380,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
"""
This function computes the intermediate result of a Mixture of Experts
......@@ -454,18 +465,27 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.activation_formats[0]}")
def _do_fused_experts(
self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int, local_num_experts: int,
self,
fused_out: Optional[torch.Tensor],
a1: torch.Tensor,
a1q: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
......@@ -509,7 +529,7 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
)
return fused_out
......@@ -533,7 +553,6 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
......@@ -541,6 +560,9 @@ class FusedMoEModularKernel(torch.nn.Module):
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_chunks = cdiv(M, CHUNK_SIZE)
# TODO(bnell): get rid of one level here, update slice functions
# to nops on num_chunks==1
if not self.fused_experts.supports_chunking() or num_chunks == 1:
return self._do_fused_experts(
fused_out=None,
......@@ -562,7 +584,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
)
# Chunking required case
assert num_chunks > 1
......@@ -618,15 +640,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens=c_expert_num_tokens,
expert_num_tokens_cpu=c_expert_num_tokens_cpu)
m = None
if extra_expert_args is not None and 'm' in extra_expert_args:
m = extra_expert_args.get('m')
if extra_expert_args is not None:
chunked_extra_expert_args = extra_expert_args
else:
chunked_extra_expert_args = {}
for chunk_idx in range(num_chunks):
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
slice_input_tensors(chunk_idx))
......@@ -637,11 +650,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta, c_topk_ids, local_num_experts,
expert_map)
s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M)
if m is not None:
chunked_extra_expert_args['m'] = e - s
self._do_fused_experts(
fused_out=slice_output_tensor(chunk_idx),
a1=a1,
......@@ -662,7 +670,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=chunked_extra_expert_args)
)
return fused_out
......@@ -684,9 +692,6 @@ class FusedMoEModularKernel(torch.nn.Module):
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
extra_expert_args: Optional[dict] = None,
extra_prepare_args: Optional[dict] = None,
extra_finalize_args: Optional[dict] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
......@@ -719,12 +724,6 @@ class FusedMoEModularKernel(torch.nn.Module):
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
- extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
fused_experts.apply.
- extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
to prepare.
- extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
to finalize.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
......@@ -748,7 +747,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
extra_prepare_args,
)
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
......@@ -786,12 +784,15 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
)
self.prepare_finalize.finalize(
output, fused_out, topk_weights, topk_ids,
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
extra_finalize_args)
)
return output
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional, Union
import pplx_kernels as pplx
import torch
......@@ -21,7 +21,7 @@ def pplx_hidden_dim_scale_bytes(
max_num_tokens: int,
hidden_dim: int,
in_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype],
quant_dtype: Union[torch.dtype, str, None],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
):
......@@ -32,6 +32,7 @@ def pplx_hidden_dim_scale_bytes(
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to 4 * sizeof(float32) (x4 for alignment)
if quant_dtype is not None:
assert isinstance(quant_dtype, torch.dtype)
assert quant_dtype.itemsize == 1
hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
elem_size = torch.float32.itemsize
......@@ -89,12 +90,16 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return self.num_dispatchers_
def prepare(
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
......@@ -213,11 +218,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
) -> None:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional
import torch
......@@ -38,7 +38,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]],
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
......@@ -50,27 +49,21 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
if (extra_prepare_args is not None
and extra_prepare_args.get("skip_quant", True)):
# Skip quantization if explicitly requested
return a1, None, None, None, None
a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
return a1q, a1q_scale, None, None, None
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
if (extra_finalize_args is not None
and extra_finalize_args.get("skip_weight_reduce", True)):
assert output.shape == fused_expert_output.shape
output.copy_(fused_expert_output)
else:
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
weight_and_reduce_impl.apply(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional
import torch
......@@ -119,18 +119,28 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
local_num_experts,
expert_tokens_meta)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
):
use_deep_gemm = (self.allow_deep_gemm
and (_valid_deep_gemm(hidden_states, w1, w2)
or is_blackwell_deep_gemm_e8m0_used()))
......@@ -158,5 +168,4 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2,
expert_tokens_meta,
apply_router_weight_on_input,
extra_expert_args,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import prod
from typing import Any, Optional, Union
from typing import Optional, Union
import torch
......@@ -189,7 +189,7 @@ def moe_kernel_quantize_input(
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.uint8: # nvfp4
elif quant_dtype == "nvfp4":
return _fp4_quantize(A,
A_scale,
is_sf_swizzled_layout=is_fp4_scale_swizzled)
......@@ -252,17 +252,3 @@ def _validate_scale_shape(
assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
def extract_required_args(
extra_args: Optional[dict[str, Any]],
required_keys: list[str],
) -> tuple[Any, ...]:
if extra_args is None:
raise ValueError("`extra_args` must be provided.")
missing_keys = [k for k in required_keys if k not in extra_args]
if missing_keys:
raise ValueError(f"Missing keys in `extra_args`: {missing_keys}")
return tuple(extra_args[k] for k in required_keys)
......@@ -241,7 +241,7 @@ class AutoRoundConfig(QuantizationConfig):
if isinstance(layer, FusedMoE):
if use_marlin:
return AWQMoEMethod(quant_args_marlin)
return AWQMoEMethod(quant_args_marlin, layer.moe)
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
......@@ -339,7 +339,7 @@ class AutoRoundConfig(QuantizationConfig):
}
return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix)
return GPTQMarlinMoEMethod(quant_args_marlin)
return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if use_marlin:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment