Unverified Commit 84166fee authored by ElizaWszola's avatar ElizaWszola Committed by GitHub
Browse files

[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)


Signed-off-by: default avatarElizaWszola <ewszola@redhat.com>
Signed-off-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
parent 6e0cd10f
......@@ -9,6 +9,9 @@ from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
......@@ -210,6 +213,7 @@ class MoEConfig:
moe_parallel_config: FusedMoEParallelConfig
in_dtype: torch.dtype # The activation type.
quant_dtype: torch.dtype = None
# TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128
......@@ -264,8 +268,22 @@ class FusedMoeWeightScaleSupported(Enum):
BLOCK = "block"
def get_quant_config_input_activations(
quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
and "Linear" in quant_config.target_scheme_map and
"input_activations" in quant_config.target_scheme_map["Linear"]):
return quant_config.target_scheme_map["Linear"].get(
"input_activations")
else:
return None
class FusedMoEMethodBase(QuantizeMethodBase):
moe: MoEConfig
@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
......@@ -277,6 +295,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.moe = moe
quant_dtype = None
act_quant_block_size = None
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
......@@ -297,13 +316,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
(moe.hidden_dim + moe.block_size - 1) // moe.block_size *
torch.float32.itemsize)),
hidden_dim_scale_bytes=(
0 if moe.quant_dtype.itemsize != 1 else
((moe.hidden_dim + moe.block_size - 1) // moe.block_size *
torch.float32.itemsize)),
)
# Intranode pplx a2a takes a group name while internode does not.
......@@ -313,6 +333,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
handle = all2all_manager.get_handle(all_to_all_args)
input_activations = get_quant_config_input_activations(
quant_config)
prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
......@@ -320,7 +343,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
rank=all2all_manager.rank,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
quant_dtype=moe.in_dtype,
quant_dtype=moe.quant_dtype,
per_act_token=(input_activations.strategy
== QuantizationStrategy.TOKEN
if input_activations is not None else False),
)
elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
......@@ -365,15 +391,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self.topk_indices_dtype = None
if prepare_finalize is not None:
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize)
experts = self.select_gemm_impl(prepare_finalize, moe)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
def select_gemm_impl(
self, prepare_finalize: FusedMoEPrepareAndFinalize
) -> FusedMoEPermuteExpertsUnpermute:
self, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
raise NotImplementedError(
......@@ -419,7 +445,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
self.rocm_aiter_fused_experts = None # type: ignore
def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize):
def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: Optional[MoEConfig]):
assert self.fused_experts == fused_experts
......@@ -809,7 +836,6 @@ class FusedMoE(torch.nn.Module):
activation: str = "silu",
):
super().__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
......@@ -869,14 +895,24 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
# Only support float8 for now.
quant_dtype = params_dtype
if quant_config is not None:
input_activations = get_quant_config_input_activations(
quant_config)
if (input_activations is not None
and input_activations.num_bits == 8
and input_activations.type == QuantizationType.FLOAT):
quant_dtype = torch.float8_e4m3fn
moe = MoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,
quant_dtype=quant_dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE,
)
self.moe_config = moe
......
......@@ -175,6 +175,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -309,7 +310,7 @@ class FusedMoEModularKernel(torch.nn.Module):
# Use a1 here to decipher the correct workspace datatype
workspace13_shape, workspace2_shape, workspace_dtype = (
self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
global_num_experts))
# We can reuse the memory between cache1 and cache3 because by the time
......
......@@ -21,7 +21,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
rank: int,
dp_size: int,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
block_shape: Optional[list[int]] = None,
per_act_token: bool = False):
super().__init__()
assert max_num_tokens > 0
self.a2a = a2a
......@@ -31,6 +32,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.rank = rank
self.dp_size = dp_size
self.quant_dtype = quant_dtype
self.per_act_token = per_act_token
def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_num_tokens
......@@ -66,13 +68,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * rank_topk_weights.to(a1.dtype)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
repeat_cols = 4
repeat_rows = 1 if self.per_act_token else a1.shape[0]
a1q, a1q_scale = moe_kernel_quantize_input(
a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
self.per_act_token, self.block_shape)
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
self.quant_dtype,
per_act_token,
self.block_shape)
if a1q_scale is not None:
a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
# rem_experts need to be 0 for pplx to work properly.
rem_experts = num_experts % self.world_size
......@@ -100,7 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else 1) * float32_size
expert_x_scale = torch.empty(
(
num_experts,
num_local_experts,
expert_x.size(1),
(expert_x.size(2) + block_size - 1) // block_size,
),
......@@ -121,6 +124,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
indices=rank_topk_ids,
bound_m=bound_m,
)
if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, 0:1]
return expert_x, expert_x_scale, expert_num_tokens, None, None
......
......@@ -37,6 +37,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -49,9 +50,9 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
a, M, N, K, topk, num_experts)
a, aq, M, N, K, topk, num_experts)
else:
return self.triton_expert.workspace_shapes(a, M, N, K, topk,
return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
num_experts)
def apply(
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import importlib
from enum import Enum
from typing import Callable, Optional
......@@ -11,7 +12,6 @@ from compressed_tensors.quantization import (ActivationOrdering,
QuantizationStrategy)
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
......@@ -30,6 +30,15 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
if current_platform.is_cuda_alike():
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize)
if has_pplx:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
logger = init_logger(__name__)
......@@ -77,8 +86,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
and layer.activation == "silu"):
elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
......@@ -421,6 +429,11 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization.")
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
self.fused_experts = cutlass_moe_fp8 # type: ignore
self.disable_expert_map = False
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
......@@ -499,25 +512,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None
layer.w2_input_scale = None
device = w13_weight.device
# TODO strides can be shared across multiple layers
self.ab_strides1 = torch.full((num_experts, ),
hidden_size,
device=device,
dtype=torch.int64)
self.c_strides1 = torch.full((num_experts, ),
2 * intermediate_size_per_partition,
device=device,
dtype=torch.int64)
self.ab_strides2 = torch.full((num_experts, ),
intermediate_size_per_partition,
device=device,
dtype=torch.int64)
self.c_strides2 = torch.full((num_experts, ),
hidden_size,
device=device,
dtype=torch.int64)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
......@@ -558,6 +552,27 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
def select_gemm_impl(self, prepare_finalize, moe):
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp8)
assert moe is not None
max_experts_per_worker = (
(moe.num_experts + prepare_finalize.world_size - 1) //
prepare_finalize.world_size)
experts = CutlassExpertsFp8(
max_experts_per_worker, moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
if has_pplx and isinstance(
prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
# no expert_map support in this case
self.disable_expert_map = True
return experts
def apply(
self,
layer: torch.nn.Module,
......@@ -577,9 +592,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", (
f"{activation} not supported for Cutlass MoE.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
......@@ -590,27 +602,22 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8
e_score_correction_bias=e_score_correction_bias,
indices_type=torch.uint32)
return cutlass_moe_fp8(
return self.fused_experts(
x,
layer.w13_weight.transpose(1, 2),
layer.w2_weight.transpose(1, 2),
layer.w13_weight_scale,
layer.w2_weight_scale,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
self.ab_strides1,
self.c_strides1,
self.ab_strides2,
self.c_strides2,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
out_dtype=x.dtype,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......
......@@ -769,7 +769,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
def select_gemm_impl(self, prepare_finalize):
def select_gemm_impl(self, prepare_finalize, moe):
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment