Unverified commit 84166fee, authored by ElizaWszola and committed by GitHub
Browse files

[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)


Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
parent 6e0cd10f
...@@ -9,6 +9,9 @@ from typing import Callable, Optional, Union ...@@ -9,6 +9,9 @@ from typing import Callable, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from torch.nn.parameter import UninitializedParameter from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs import vllm.envs as envs
...@@ -210,6 +213,7 @@ class MoEConfig: ...@@ -210,6 +213,7 @@ class MoEConfig:
moe_parallel_config: FusedMoEParallelConfig moe_parallel_config: FusedMoEParallelConfig
in_dtype: torch.dtype # The activation type. in_dtype: torch.dtype # The activation type.
quant_dtype: torch.dtype = None
# TODO: add more quantization params, blocked, per-token, etc. # TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128 block_size: int = 128
...@@ -264,8 +268,22 @@ class FusedMoeWeightScaleSupported(Enum): ...@@ -264,8 +268,22 @@ class FusedMoeWeightScaleSupported(Enum):
BLOCK = "block" BLOCK = "block"
def get_quant_config_input_activations(
    quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
    """Return the input-activation settings of the "Linear" target scheme.

    The value is read from
    ``quant_config.target_scheme_map["Linear"]["input_activations"]``.

    Args:
        quant_config: The quantization config to inspect, or ``None``.

    Returns:
        The stored ``"input_activations"`` entry, or ``None`` when
        ``quant_config`` is ``None``, has no ``target_scheme_map`` attribute,
        has no ``"Linear"`` entry, or that entry has no
        ``"input_activations"`` key.
    """
    # getattr with a default covers both quant_config=None and configs that
    # do not define target_scheme_map at all.
    scheme_map = getattr(quant_config, "target_scheme_map", None)
    if not scheme_map:
        return None

    linear_scheme = scheme_map.get("Linear")
    if not linear_scheme:
        return None

    # A single .get() replaces the membership test followed by a second lookup.
    return linear_scheme.get("input_activations")
class FusedMoEMethodBase(QuantizeMethodBase): class FusedMoEMethodBase(QuantizeMethodBase):
moe: MoEConfig
@abstractmethod @abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
...@@ -277,6 +295,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -277,6 +295,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None assert all2all_manager is not None
self.moe = moe
quant_dtype = None quant_dtype = None
act_quant_block_size = None act_quant_block_size = None
from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.fp8 import Fp8Config
...@@ -297,12 +316,13 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -297,12 +316,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# dp_size actually means tp_size, bug in pplx kernels # dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size, dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim, hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
# For blocked per token: set to # For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32) # ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32) # For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( hidden_dim_scale_bytes=(
(moe.hidden_dim + moe.block_size - 1) // moe.block_size * 0 if moe.quant_dtype.itemsize != 1 else
((moe.hidden_dim + moe.block_size - 1) // moe.block_size *
torch.float32.itemsize)), torch.float32.itemsize)),
) )
...@@ -313,6 +333,9 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -313,6 +333,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
handle = all2all_manager.get_handle(all_to_all_args) handle = all2all_manager.get_handle(all_to_all_args)
input_activations = get_quant_config_input_activations(
quant_config)
prepare_finalize = PplxPrepareAndFinalize( prepare_finalize = PplxPrepareAndFinalize(
handle, handle,
max_num_tokens=moe.max_num_tokens, max_num_tokens=moe.max_num_tokens,
...@@ -320,7 +343,10 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -320,7 +343,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
rank=all2all_manager.rank, rank=all2all_manager.rank,
# dp_size actually means tp_size, bug in pplx kernels # dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size, dp_size=all2all_manager.tp_group.world_size,
quant_dtype=moe.in_dtype, quant_dtype=moe.quant_dtype,
per_act_token=(input_activations.strategy
== QuantizationStrategy.TOKEN
if input_activations is not None else False),
) )
elif moe.use_deepep_ht_kernels: elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size assert moe.dp_size == all2all_manager.dp_world_size
...@@ -365,15 +391,15 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -365,15 +391,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self.topk_indices_dtype = None self.topk_indices_dtype = None
if prepare_finalize is not None: if prepare_finalize is not None:
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize) experts = self.select_gemm_impl(prepare_finalize, moe)
self.fused_experts = FusedMoEModularKernel( self.fused_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
) )
def select_gemm_impl( def select_gemm_impl(
self, prepare_finalize: FusedMoEPrepareAndFinalize self, prepare_finalize: FusedMoEPrepareAndFinalize,
) -> FusedMoEPermuteExpertsUnpermute: moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate # based on the all2all implementation, select the appropriate
# gemm implementation # gemm implementation
raise NotImplementedError( raise NotImplementedError(
...@@ -419,7 +445,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -419,7 +445,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else: else:
self.rocm_aiter_fused_experts = None # type: ignore self.rocm_aiter_fused_experts = None # type: ignore
def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize): def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: Optional[MoEConfig]):
assert self.fused_experts == fused_experts assert self.fused_experts == fused_experts
...@@ -809,7 +836,6 @@ class FusedMoE(torch.nn.Module): ...@@ -809,7 +836,6 @@ class FusedMoE(torch.nn.Module):
activation: str = "silu", activation: str = "silu",
): ):
super().__init__() super().__init__()
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
...@@ -869,14 +895,24 @@ class FusedMoE(torch.nn.Module): ...@@ -869,14 +895,24 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
# Only support float8 for now.
quant_dtype = params_dtype
if quant_config is not None:
input_activations = get_quant_config_input_activations(
quant_config)
if (input_activations is not None
and input_activations.num_bits == 8
and input_activations.type == QuantizationType.FLOAT):
quant_dtype = torch.float8_e4m3fn
moe = MoEConfig( moe = MoEConfig(
num_experts=self.global_num_experts, num_experts=self.global_num_experts,
experts_per_token=top_k, experts_per_token=top_k,
hidden_dim=hidden_size, hidden_dim=hidden_size,
num_local_experts=self.local_num_experts, num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config, moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype, in_dtype=params_dtype,
quant_dtype=quant_dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE, max_num_tokens=MOE_DP_CHUNK_SIZE,
) )
self.moe_config = moe self.moe_config = moe
......
...@@ -175,6 +175,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -175,6 +175,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -309,7 +310,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -309,7 +310,7 @@ class FusedMoEModularKernel(torch.nn.Module):
# Use a1 here to decipher the correct workspace datatype # Use a1 here to decipher the correct workspace datatype
workspace13_shape, workspace2_shape, workspace_dtype = ( workspace13_shape, workspace2_shape, workspace_dtype = (
self.fused_experts.workspace_shapes(a1, M, N, K, top_k, self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
global_num_experts)) global_num_experts))
# We can reuse the memory between cache1 and cache3 because by the time # We can reuse the memory between cache1 and cache3 because by the time
......
...@@ -21,7 +21,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -21,7 +21,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
rank: int, rank: int,
dp_size: int, dp_size: int,
quant_dtype: Optional[torch.dtype] = None, quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None): block_shape: Optional[list[int]] = None,
per_act_token: bool = False):
super().__init__() super().__init__()
assert max_num_tokens > 0 assert max_num_tokens > 0
self.a2a = a2a self.a2a = a2a
...@@ -31,6 +32,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -31,6 +32,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.rank = rank self.rank = rank
self.dp_size = dp_size self.dp_size = dp_size
self.quant_dtype = quant_dtype self.quant_dtype = quant_dtype
self.per_act_token = per_act_token
def max_num_tokens_per_rank(self) -> Optional[int]: def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_num_tokens return self.max_num_tokens
...@@ -66,13 +68,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -66,13 +68,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"apply_router_weight_on_input is only implemented for topk=1") "apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * rank_topk_weights.to(a1.dtype) a1 = a1 * rank_topk_weights.to(a1.dtype)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( repeat_cols = 4
a2_scale.numel() != 1 if a2_scale is not None else False) repeat_rows = 1 if self.per_act_token else a1.shape[0]
a1q, a1q_scale = moe_kernel_quantize_input(
a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
self.per_act_token, self.block_shape)
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, if a1q_scale is not None:
self.quant_dtype, a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
per_act_token,
self.block_shape)
# rem_experts need to be 0 for pplx to work properly. # rem_experts need to be 0 for pplx to work properly.
rem_experts = num_experts % self.world_size rem_experts = num_experts % self.world_size
...@@ -100,7 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -100,7 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else 1) * float32_size else 1) * float32_size
expert_x_scale = torch.empty( expert_x_scale = torch.empty(
( (
num_experts, num_local_experts,
expert_x.size(1), expert_x.size(1),
(expert_x.size(2) + block_size - 1) // block_size, (expert_x.size(2) + block_size - 1) // block_size,
), ),
...@@ -121,6 +124,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -121,6 +124,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
indices=rank_topk_ids, indices=rank_topk_ids,
bound_m=bound_m, bound_m=bound_m,
) )
if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, 0:1]
return expert_x, expert_x_scale, expert_num_tokens, None, None return expert_x, expert_x_scale, expert_num_tokens, None, None
......
...@@ -37,6 +37,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -37,6 +37,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -49,9 +50,9 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -49,9 +50,9 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
assert self.deep_gemm_expert is not None assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes( return self.deep_gemm_expert.workspace_shapes(
a, M, N, K, topk, num_experts) a, aq, M, N, K, topk, num_experts)
else: else:
return self.triton_expert.workspace_shapes(a, M, N, K, topk, return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
num_experts) num_experts)
def apply( def apply(
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum import enum
import importlib
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
...@@ -11,7 +12,6 @@ from compressed_tensors.quantization import (ActivationOrdering, ...@@ -11,7 +12,6 @@ from compressed_tensors.quantization import (ActivationOrdering,
QuantizationStrategy) QuantizationStrategy)
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
...@@ -30,6 +30,15 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -30,6 +30,15 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
if current_platform.is_cuda_alike():
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize)
if has_pplx:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -77,8 +86,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -77,8 +86,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
else: else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(quant_config) return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
and layer.activation == "silu"):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant): elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config) return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
...@@ -421,6 +429,11 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -421,6 +429,11 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
"For FP8 Fused MoE layer, we require either per tensor or " "For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization.") "channelwise, dynamic per token quantization.")
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
self.fused_experts = cutlass_moe_fp8 # type: ignore
self.disable_expert_map = False
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
...@@ -499,25 +512,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -499,25 +512,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
device = w13_weight.device
# TODO strides can be shared across multiple layers
self.ab_strides1 = torch.full((num_experts, ),
hidden_size,
device=device,
dtype=torch.int64)
self.c_strides1 = torch.full((num_experts, ),
2 * intermediate_size_per_partition,
device=device,
dtype=torch.int64)
self.ab_strides2 = torch.full((num_experts, ),
intermediate_size_per_partition,
device=device,
dtype=torch.int64)
self.c_strides2 = torch.full((num_experts, ),
hidden_size,
device=device,
dtype=torch.int64)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale. # Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ. # We take the max of all the scales in case they differ.
...@@ -558,6 +552,27 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -558,6 +552,27 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False) requires_grad=False)
def select_gemm_impl(self, prepare_finalize, moe):
    """Build the CUTLASS FP8 experts object for the given prepare/finalize.

    Args:
        prepare_finalize: The prepare/finalize object in use; its
            ``world_size`` is read to size the per-worker expert count.
        moe: The MoE config; must not be ``None``. ``num_experts`` and
            ``in_dtype`` are read from it.

    Returns:
        A ``CutlassExpertsFp8`` instance.

    Side effects:
        Sets ``self.disable_expert_map`` to ``True`` when ``has_pplx`` is true
        and ``prepare_finalize`` is a ``BatchedPrepareAndFinalize`` or
        ``PplxPrepareAndFinalize``.
    """
    from vllm.model_executor.layers.fused_moe.cutlass_moe import (
        CutlassExpertsFp8)

    assert moe is not None

    # Ceiling division: experts are spread over world_size workers, rounded up.
    world_size = prepare_finalize.world_size
    max_experts_per_worker = (moe.num_experts + world_size - 1) // world_size

    acts_are_per_token = (
        self.input_quant.strategy == QuantizationStrategy.TOKEN)
    weights_are_per_channel = (
        self.weight_quant.strategy == QuantizationStrategy.CHANNEL)

    # NOTE(review): the meaning of the two boolean positional arguments is
    # defined by CutlassExpertsFp8, which is not visible here -- confirm there.
    experts = CutlassExpertsFp8(
        max_experts_per_worker,
        moe.in_dtype,
        acts_are_per_token,
        weights_are_per_channel,
    )

    batched_kinds = (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)
    if has_pplx and isinstance(prepare_finalize, batched_kinds):
        # no expert_map support in this case
        self.disable_expert_map = True

    return experts
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -577,9 +592,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -577,9 +592,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", (
f"{activation} not supported for Cutlass MoE.")
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -590,27 +602,22 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -590,27 +602,22 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=torch.uint32)
from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8 return self.fused_experts(
return cutlass_moe_fp8(
x, x,
layer.w13_weight.transpose(1, 2), layer.w13_weight,
layer.w2_weight.transpose(1, 2), layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
self.ab_strides1, activation=activation,
self.c_strides1, global_num_experts=global_num_experts,
self.ab_strides2, expert_map=None if self.disable_expert_map else expert_map,
self.c_strides2, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
out_dtype=x.dtype,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
) )
......
...@@ -769,7 +769,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -769,7 +769,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale del layer.w13_input_scale
del layer.w2_input_scale del layer.w2_input_scale
def select_gemm_impl(self, prepare_finalize): def select_gemm_impl(self, prepare_finalize, moe):
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts) BatchedTritonOrDeepGemmExperts)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment