Unverified commit 5f76b3fb, authored by Robert Shaw, committed by GitHub
Browse files

[MoE] Convert CT W8A8 To Oracle Structure (#39187)


Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
parent 809d83c2
......@@ -928,16 +928,16 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
p.is_cuda() and p.has_device_capability((8, 9))
)
SUPPORTED_W_A_FP8 = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
]
return (weight_key, activation_key) == (None, None) or (
device_supports_fp8 and (weight_key, activation_key) in SUPPORTED_W_A_FP8
)
supported: list[tuple[QuantKey | None, QuantKey | None]] = [(None, None)]
if device_supports_fp8:
supported += [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
]
return (weight_key, activation_key) in supported
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
......
......@@ -46,6 +46,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kInt8DynamicTokenSym,
kInt8StaticChannelSym,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
......@@ -1952,18 +1954,24 @@ class TritonExperts(mk.FusedMoEExpertsModular):
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
if not current_platform.supports_fp8():
return (weight_key, activation_key) == (None, None)
SUPPORTED_W_A = [
(None, None),
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
# INT8 requires at least 7.5 (Turing).
device_supports_int8 = (
current_platform.is_cuda()
and current_platform.has_device_capability((7, 5))
)
supported: list[tuple[QuantKey | None, QuantKey | None]] = [(None, None)]
if device_supports_int8:
supported.append((kInt8StaticChannelSym, kInt8DynamicTokenSym))
if current_platform.supports_fp8():
supported += [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
]
return (weight_key, activation_key) in supported
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
......@@ -11,46 +14,165 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
int8_w8a8_moe_quant_config,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kInt8DynamicTokenSym,
kInt8StaticChannelSym,
)
logger = init_logger(__name__)
class Int8MoeBackend(Enum):
    """Kernel backends that can run Int8 fused-MoE layers.

    Each member's value is the backend's name as a string.
    """

    TRITON = "TRITON"
def _get_priority_backends(
    moe_config: FusedMoEConfig,
) -> list[Int8MoeBackend]:
    """Return the candidate Int8 MoE backends in priority order.

    Args:
        moe_config: MoE layer configuration. Not consulted at present:
            Triton is the only backend returned regardless of its contents.

    Returns:
        A new list of backends, highest priority first.
    """
    return [Int8MoeBackend.TRITON]
def backend_to_kernel_cls(
    backend: Int8MoeBackend,
) -> list[type[mk.FusedMoEExperts]]:
    """Return the experts kernel classes that implement ``backend``.

    Args:
        backend: The Int8 MoE backend to resolve.

    Returns:
        The kernel classes for that backend (a one-element list for Triton).

    Raises:
        ValueError: If ``backend`` is not a known Int8 MoE backend.
    """
    if backend == Int8MoeBackend.TRITON:
        # Function-scope import, kept from the original.
        # NOTE(review): presumably avoids an import cycle -- confirm.
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            TritonExperts,
        )

        return [TritonExperts]

    # Fall back to the raw object when it has no ``.value`` so that a
    # non-enum argument still produces the documented ValueError instead of
    # an AttributeError raised while formatting the message.
    backend_name = getattr(backend, "value", backend)
    raise ValueError(f"Unknown Int8 MoE backend: {backend_name}")
def map_int8_backend(runner_backend: MoEBackend) -> Int8MoeBackend:
    """Map the user's ``MoEBackend`` selection to an ``Int8MoeBackend``.

    Args:
        runner_backend: The backend name requested by the user.

    Returns:
        The Int8 MoE backend registered under that name.

    Raises:
        ValueError: If ``runner_backend`` has no Int8 MoE equivalent.
    """
    mapping = {
        "triton": Int8MoeBackend.TRITON,
    }
    # Explicit ``is not None`` so the lookup does not depend on the
    # truthiness of the enum member.
    if (backend := mapping.get(runner_backend)) is not None:
        return backend

    raise ValueError(
        f"moe_backend='{runner_backend}' is not supported for Int8 MoE. "
        f"Expected one of {list(mapping)}."
    )
def select_int8_moe_backend(
config: FusedMoEConfig,
) -> type[mk.FusedMoEExperts]:
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
supported, reason = TritonExperts.is_supported_config(
TritonExperts,
config,
None,
None,
mk.FusedMoEActivationFormat.Standard,
weight_key: QuantKey | None = kInt8StaticChannelSym,
activation_key: QuantKey | None = kInt8DynamicTokenSym,
) -> tuple[Int8MoeBackend, type[mk.FusedMoEExperts]]:
"""
Select the primary Int8 MoE backend.
Note: Shape-specific fallbacks may still occur at runtime.
"""
if config.is_lora_enabled:
return Int8MoeBackend.TRITON, backend_to_kernel_cls(Int8MoeBackend.TRITON)[0]
AVAILABLE_BACKENDS = _get_priority_backends(config)
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if config.moe_parallel_config.use_batched_activation_format
else mk.FusedMoEActivationFormat.Standard
)
if not supported:
raise ValueError(
f"INT8 Triton MoE backend does not support the "
f"deployment configuration: {reason}"
def _make_log_backend(backend: Int8MoeBackend) -> str:
available_backend_strs = [b.value for b in AVAILABLE_BACKENDS]
return (
f"Using {backend.value} Int8 MoE backend out "
f"of potential backends: {available_backend_strs}."
)
logger.info_once("Using Triton INT8 MoE backend", scope="local")
return TritonExperts
def _make_log_unsupported(backend: Int8MoeBackend, reason: str | None) -> str:
if reason:
return (
f"Int8 MoE backend {backend.value} does not support the "
f"deployment configuration since {reason}."
)
else:
return (
f"Int8 MoE backend '{backend.value}' does not support the "
"deployment configuration."
)
def _return_or_raise(
backend: Int8MoeBackend,
) -> tuple[Int8MoeBackend, type[mk.FusedMoEExperts]]:
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit moe_backend from user.
runner_backend = config.moe_backend
if runner_backend != "auto":
requested_backend = map_int8_backend(runner_backend)
return _return_or_raise(requested_backend)
# Select kernels in order of backend.
for backend in AVAILABLE_BACKENDS:
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError(
"No Int8 MoE backend supports the deployment configuration."
)
def make_int8_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
per_act_token_quant: bool = False,
) -> FusedMoEQuantConfig:
return int8_w8a16_moe_quant_config(
assert (a1_scale is None and a2_scale is None) or (
a1_scale is not None and a2_scale is not None
), "a1_scale and a2_scale must both be provided or both be None"
if a1_scale is None or a2_scale is None:
return int8_w8a16_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=None,
w2_zp=None,
)
return int8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=None,
w2_zp=None,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_act_token_quant=per_act_token_quant,
)
......@@ -61,24 +183,39 @@ def make_int8_moe_kernel(
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: SharedExperts | None = None,
) -> mk.FusedMoEKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
)
assert prepare_finalize is not None
logger.info_once("Using %s", prepare_finalize.__class__.__name__, scope="local")
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
# Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens is not None
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
else:
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
return mk.FusedMoEKernel(
kernel = mk.FusedMoEKernel(
prepare_finalize,
experts,
shared_experts=shared_experts,
inplace=not moe_config.disable_inplace,
)
return kernel
......@@ -8,6 +8,7 @@ from compressed_tensors.quantization import (
QuantizationStrategy,
)
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
......@@ -16,17 +17,27 @@ from vllm.model_executor.layers.fused_moe import (
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
int8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.oracle.int8 import (
make_int8_moe_kernel,
make_int8_moe_quant_config,
select_int8_moe_backend,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501
CompressedTensorsMoEMethod,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kInt8DynamicTokenSym,
kInt8StaticChannelSym,
)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
"""W8A8 Int8 MoE quantization using compressed tensors."""
def __init__(
self,
weight_quant: QuantizationArgs,
......@@ -56,6 +67,13 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
"dynamic per token quantization. Found static input scales."
)
# Select Int8 MoE backend.
self.int8_backend, self.experts_cls = select_int8_moe_backend(
config=self.moe,
weight_key=kInt8StaticChannelSym,
activation_key=kInt8DynamicTokenSym,
)
def create_weights(
self,
layer: torch.nn.Module,
......@@ -122,13 +140,28 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
def process_weights_after_loading(self, layer: FusedMoE) -> None:
    """Build the quant config and the modular MoE kernel for ``layer``.

    Stores the result of ``get_fused_moe_quant_config`` on
    ``self.moe_quant_config`` and the constructed kernel on
    ``self.moe_kernel``.

    Args:
        layer: The fused-MoE layer; supplies the routing tables and the
            shared experts handed to the kernel factory.
    """
    self.moe_quant_config = self.get_fused_moe_quant_config(layer)

    # The kernel factory needs a concrete experts class.
    assert self.experts_cls is not None
    self.moe_kernel = make_int8_moe_kernel(
        moe_quant_config=self.moe_quant_config,
        moe_config=self.moe,
        experts_cls=self.experts_cls,
        routing_tables=layer._maybe_init_expert_routing_tables(),
        shared_experts=layer.shared_experts,
    )
def maybe_make_prepare_finalize(
    self,
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
    """Reject use of this hook; this method always raises.

    Args:
        routing_tables: Accepted for signature compatibility; never read.

    Raises:
        ValueError: Always. The message states that this class uses the new
            modular kernel initialization logic and that this function
            should not be called.
    """
    raise ValueError(
        f"{self.__class__.__name__} uses the new modular kernel initialization "
        "logic. This function should not be called."
    )
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return int8_w8a8_moe_quant_config(
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
return make_int8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
......@@ -144,18 +177,17 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
assert not self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
......@@ -7,7 +7,6 @@ import torch
from torch.nn import Module
if TYPE_CHECKING:
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
......@@ -21,6 +20,10 @@ from vllm.model_executor.layers.fused_moe.oracle.int8 import (
from vllm.model_executor.layers.quantization.online.moe_base import (
OnlineMoEMethodBase,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kInt8DynamicTokenSym,
kInt8StaticChannelSym,
)
from vllm.model_executor.utils import replace_parameter
......@@ -35,8 +38,10 @@ class Int8OnlineMoEMethod(OnlineMoEMethodBase):
layer: torch.nn.Module,
):
super().__init__(layer.moe_config)
self.experts_cls: type[mk.FusedMoEExperts] = select_int8_moe_backend(
self.int8_backend, self.experts_cls = select_int8_moe_backend(
config=self.moe,
weight_key=kInt8StaticChannelSym,
activation_key=kInt8DynamicTokenSym,
)
def process_weights_after_loading(self, layer: Module) -> None:
......
......@@ -170,6 +170,9 @@ kMxfp8Dynamic = QuantKey(FP8_DTYPE, scale=kMxfp8DynamicGroupScale, symmetric=Tru
kMxfp4StaticGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, True, GroupShape(1, 32))
kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True)
kInt8StaticChannelSym = QuantKey(torch.int8, kStaticChannelScale, symmetric=True)
kInt8DynamicTokenSym = QuantKey(torch.int8, kDynamicTokenScale, symmetric=True)
def create_fp8_quant_key(
static: bool,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment