Commit 0dc059af authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.15.1-dev' into v0.15.1-dev

# Conflicts:
#	vllm/v1/worker/gpu_model_runner.py
parents c3270a92 ca9ce18d
...@@ -361,12 +361,12 @@ void static_scaled_fp8_quant( ...@@ -361,12 +361,12 @@ void static_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale, torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale,
std::optional<std::tuple<int64_t, int64_t>> group_shape = std::nullopt); std::optional<std::tuple<int64_t, int64_t>> group_shape = std::nullopt);
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& scale); torch::Tensor& scale);
// void dynamic_per_token_scaled_fp8_quant( void dynamic_per_token_scaled_fp8_quant(
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
// std::optional<torch::Tensor> const& scale_ub); std::optional<torch::Tensor> const& scale_ub);
void selective_scan_fwd( void selective_scan_fwd(
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
......
...@@ -625,19 +625,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -625,19 +625,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
// ops.def( ops.def(
// "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) " "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
// "-> " "-> "
// "()"); "()");
// ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor. // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
// ops.def( ops.def(
// "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, " "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
// "Tensor! scale, Tensor? scale_ub) -> " "Tensor! scale, Tensor? scale_ub) -> "
// "()"); "()");
// ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
// &dynamic_per_token_scaled_fp8_quant); &dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor. // Compute int8 quantized tensor for given scaling factor.
ops.def( ops.def(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Literal, Optional from typing import TYPE_CHECKING, Literal, Optional, Union
import torch import torch
...@@ -1900,6 +1900,28 @@ def scaled_fp4_experts_quant( ...@@ -1900,6 +1900,28 @@ def scaled_fp4_experts_quant(
output_scales = output_scales.view(torch.float8_e4m3fn) output_scales = output_scales.view(torch.float8_e4m3fn)
return output, output_scales return output, output_scales
def _lightop_per_token_quant_fp8_impl(
out: torch.Tensor,
input: torch.Tensor,
scales: torch.Tensor,
) -> None:
from lightop import op
op.per_token_quant_fp8(out, input, scales)
def _lightop_per_token_quant_fp8_fake(
out: torch.Tensor,
input: torch.Tensor,
scales: torch.Tensor,
) -> None:
pass
direct_register_custom_op(
"lightop_per_token_quant_fp8",
_lightop_per_token_quant_fp8_impl,
mutates_args=["out", "scales"],
fake_impl=_lightop_per_token_quant_fp8_fake,
)
def scaled_fp8_quant( def scaled_fp8_quant(
input: torch.Tensor, input: torch.Tensor,
scale: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None,
...@@ -1952,7 +1974,74 @@ def scaled_fp8_quant( ...@@ -1952,7 +1974,74 @@ def scaled_fp8_quant(
dtype=torch.float32) dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant( # torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub) # output, input.contiguous(), scale, scale_ub)
output, scale = per_token_quant_fp8(input.contiguous()) # output, scale = per_token_quant_fp8(input.contiguous())
output = torch.empty_like(input, device=input.device, dtype=torch.float8_e4m3fn)
scale = torch.empty(shape[:-1] + (1, ),
device=input.device,
dtype=torch.float32)
torch.ops.vllm.lightop_per_token_quant_fp8(output, input, scale)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
assert scale.numel() == 1, f"{scale.shape}"
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale
def scaled_fp8_quant_weight(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
group_shape: Optional[tuple[int, int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
optional padding of the output tensors for downstream kernels that
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape: Union[tuple[int, int], torch.Size] = input.shape
# For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype
if scale is None:
if use_per_token_if_dynamic:
scale = torch.empty((shape[0], 1),
device=input.device,
dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
output, input.contiguous(), scale, scale_ub)
# output, scale = per_token_quant_fp8(input.contiguous())
else: else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32) scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
...@@ -2027,76 +2116,6 @@ def silu_and_mul_scaled_fp4_experts_quant( ...@@ -2027,76 +2116,6 @@ def silu_and_mul_scaled_fp4_experts_quant(
return output, output_scales return output, output_scales
# fp8
# def scaled_fp8_quant(
# input: torch.Tensor,
# scale: torch.Tensor | None = None,
# num_token_padding: int | None = None,
# scale_ub: torch.Tensor | None = None,
# use_per_token_if_dynamic: bool = False,
# output: torch.Tensor | None = None,
# group_shape: tuple[int, int] | None = None,
# ) -> tuple[torch.Tensor, torch.Tensor]:
# """
# Quantize input tensor to FP8 and return quantized tensor and scale.
# This function supports both static and dynamic quantization: If you
# provide the scale, it will use static scaling and if you omit it,
# the scale will be determined dynamically. The function also allows
# optional padding of the output tensors for downstream kernels that
# will benefit from padding.
# Args:
# input: The input tensor to be quantized to FP8 (must be 2D: [M, N])
# scale: Optional scaling factor for the FP8 quantization. Supports:
# - 0D or [1]: per-tensor scaling
# - 1D: requires explicit group_shape to disambiguate per-channel
# vs per-token (use (-1, 1) for per-channel, (1, -1) for per-token)
# - 2D [M/group_m, N/group_n]: group scaling (e.g. [M, N/128] for
# DeepSeek-style (1,128) groups, or [M/128, N/128] for (128,128))
# scale_ub: Optional upper bound for scaling factor in dynamic
# per token case
# num_token_padding: If specified, pad the first dimension
# of the output to at least this value.
# use_per_token_if_dynamic: Whether to do per_tensor or per_token
# in the dynamic quantization case.
# group_shape: Optional tuple (group_m, group_n) specifying the group
# shape for static quantization. Use -1 for "full extent" (e.g.,
# (-1, -1) for per-tensor, (-1, 1) for per-channel, etc.)
# Required for 1D scales; optional for 2D scales.
# Returns:
# tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# scaling factor.
# """
# # This code assumes batch_dim and num_tokens are flattened
# assert input.ndim == 2
# shape: tuple[int, int] | torch.Size = input.shape
# # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
# out_dtype: torch.dtype = current_platform.fp8_dtype()
# if num_token_padding:
# shape = (max(num_token_padding, input.shape[0]), shape[1])
# if output is None:
# output = torch.empty(shape, device=input.device, dtype=out_dtype)
# else:
# assert num_token_padding is None, "padding not supported if output passed in"
# assert output.dtype == out_dtype
# if scale is None:
# if use_per_token_if_dynamic:
# scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input, scale, scale_ub
# )
# else:
# scale = torch.empty(1, device=input.device, dtype=torch.float32)
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
# else:
# torch.ops._C.static_scaled_fp8_quant(output, input, scale, group_shape)
# return output, scale
# gptq allspark # gptq allspark
def allspark_repack_weight( def allspark_repack_weight(
qweight: torch.Tensor, qweight: torch.Tensor,
......
...@@ -167,6 +167,7 @@ if TYPE_CHECKING: ...@@ -167,6 +167,7 @@ if TYPE_CHECKING:
VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_MOE_USE_DEEP_GEMM: bool = True
VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True
VLLM_USE_AITER_MOE_W8A8: bool = True
VLLM_DEEP_GEMM_WARMUP: Literal[ VLLM_DEEP_GEMM_WARMUP: Literal[
"skip", "skip",
"full", "full",
...@@ -1290,6 +1291,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1290,6 +1291,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool( "VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool(
int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1")) int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1"))
), ),
"VLLM_USE_AITER_MOE_W8A8": lambda: bool(
int(os.getenv("VLLM_USE_AITER_MOE_W8A8", "1"))
),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no # JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine # JIT'ing in the hot-path. However, this warmup increases the engine
......
...@@ -157,11 +157,8 @@ def maybe_make_prepare_finalize( ...@@ -157,11 +157,8 @@ def maybe_make_prepare_finalize(
# Note: We may want to use FP8 dispatch just to reduce # Note: We may want to use FP8 dispatch just to reduce
# data movement. # data movement.
use_fp8_dispatch = ( use_fp8_dispatch = quant_config.quant_dtype == current_platform.fp8_dtype()
quant_config.quant_dtype == current_platform.fp8_dtype()
and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE
)
use_int8_dispatch = quant_config.quant_dtype == torch.int8 use_int8_dispatch = quant_config.quant_dtype == torch.int8
prepare_finalize = DeepEPLLPrepareAndFinalize( prepare_finalize = DeepEPLLPrepareAndFinalize(
......
...@@ -38,10 +38,10 @@ from vllm.utils.math_utils import cdiv, round_up ...@@ -38,10 +38,10 @@ from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.import_utils import has_deep_gemm from vllm.utils.import_utils import has_deep_gemm
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from lightop import fuse_silu_mul_quant_ep from lightop import fuse_silu_mul_quant_ep, fuse_silu_mul_fp8_quant_ep
from lmslim.layers.gemm.int8_utils import per_token_quant_int8 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
if has_deep_gemm(): if has_deep_gemm():
from deepgemm import m_grouped_w8a8_gemm_nt_masked from deepgemm import m_grouped_w8a8_gemm_nt_masked, m_grouped_fp8_gemm_nt_masked
else: else:
from lightop import m_grouped_w8a8_gemm_nt_masked from lightop import m_grouped_w8a8_gemm_nt_masked
...@@ -452,8 +452,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -452,8 +452,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
num_dispatchers=num_dispatchers, num_dispatchers=num_dispatchers,
) )
if quant_config.use_fp8_w8a8: #if quant_config.use_fp8_w8a8:
assert self.block_shape == get_mk_alignment_for_contiguous_layout() #assert self.block_shape == get_mk_alignment_for_contiguous_layout()
self.N = N self.N = N
self.K = K self.K = K
...@@ -606,7 +606,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -606,7 +606,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m = self.get_expected_m() expected_m = self.get_expected_m()
if self.quant_config.use_fp8_w8a16 or self.quant_config.use_fp8_w8a8: if self.quant_config.use_fp8_w8a16 or self.quant_config.use_fp8_w8a8:
fp8_m_grouped_gemm_nt_masked( m_grouped_fp8_gemm_nt_masked(
(a1q, a1q_scale), (a1q, a1q_scale),
(w1, self.w1_scale), (w1, self.w1_scale),
workspace1, workspace1,
...@@ -614,14 +614,13 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -614,14 +614,13 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m, expected_m,
) )
quant_scale_fmt = DeepGemmQuantScaleFMT.from_oracle() a2q, a2q_scale = fuse_silu_mul_fp8_quant_ep(
a2q, a2q_scale = persistent_masked_m_silu_mul_quant( input=workspace1,
workspace1, fp8type=0,
expert_num_tokens, tokens_per_expert=expert_num_tokens,
quant_scale_fmt=quant_scale_fmt,
) )
fp8_m_grouped_gemm_nt_masked( m_grouped_fp8_gemm_nt_masked(
(a2q, a2q_scale), (a2q, a2q_scale),
(w2, self.w2_scale), (w2, self.w2_scale),
output, output,
......
...@@ -87,14 +87,14 @@ def _quant_flags_to_group_shape( ...@@ -87,14 +87,14 @@ def _quant_flags_to_group_shape(
""" """
a_shape: GroupShape | None a_shape: GroupShape | None
w_shape: GroupShape | None w_shape: GroupShape | None
if block_shape is not None and quant_dtype!=torch.int8: if block_shape is not None and quant_dtype!=torch.int8 and quant_dtype!=current_platform.fp8_dtype():
assert not per_act_token_quant assert not per_act_token_quant
assert not per_out_ch_quant assert not per_out_ch_quant
# TODO(bnell): this is not quite right for activations since first # TODO(bnell): this is not quite right for activations since first
# dim should be 1. # dim should be 1.
a_shape = GroupShape(row=block_shape[0], col=block_shape[1]) a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
w_shape = GroupShape(row=block_shape[0], col=block_shape[1]) w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
elif block_shape is not None and quant_dtype == torch.int8: elif block_shape is not None and (quant_dtype == torch.int8 or quant_dtype == current_platform.fp8_dtype()):
a_shape = GroupShape(row=block_shape[0], col=block_shape[1]) a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
w_shape = GroupShape(row=block_shape[0], col=block_shape[1]) w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
else: else:
...@@ -518,7 +518,7 @@ class FusedMoEQuantConfig: ...@@ -518,7 +518,7 @@ class FusedMoEQuantConfig:
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
), ),
) )
if quant_dtype != torch.int8: if quant_dtype != torch.int8 and quant_dtype != current_platform.fp8_dtype():
assert quant_config.per_act_token_quant == per_act_token_quant assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant assert quant_config.per_out_ch_quant == per_out_ch_quant
assert quant_config.block_shape == block_shape assert quant_config.block_shape == block_shape
......
...@@ -22,7 +22,7 @@ from vllm.v1.worker.ubatching import ( ...@@ -22,7 +22,7 @@ from vllm.v1.worker.ubatching import (
dbo_enabled, dbo_enabled,
dbo_maybe_run_recv_hook, dbo_maybe_run_recv_hook,
) )
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
# DeepEP kernels quantize dispatch inputs in 128 element chunks. # DeepEP kernels quantize dispatch inputs in 128 element chunks.
...@@ -179,7 +179,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -179,7 +179,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
if quant_config.block_shape is not None if quant_config.block_shape is not None
else None else None
) )
if block_k == DEEPEP_QUANT_BLOCK_SIZE: if block_k == DEEPEP_QUANT_BLOCK_SIZE or (isinstance(x, tuple) and x[0].dtype == current_platform.fp8_dtype()):
# DeepEP kernels did the quantization for us. # DeepEP kernels did the quantization for us.
x, x_scales = x x, x_scales = x
return x, x_scales return x, x_scales
......
...@@ -6,7 +6,9 @@ import functools ...@@ -6,7 +6,9 @@ import functools
import json import json
import os import os
import math import math
import sys
import aiter
from aiter.moe import get_aiter_moe_config, aiter_moe, MoeQuantType
from collections.abc import Callable from collections.abc import Callable
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
...@@ -70,6 +72,23 @@ from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_ ...@@ -70,6 +72,23 @@ from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_
from lightop import fuse_silu_and_mul from lightop import fuse_silu_and_mul
from lightop import op as op from lightop import op as op
try:
if envs.VLLM_ROCM_USE_AITER_MOE:
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
else:
raise Exception("VLLM_ROCM_USE_AITER_MOE not set.")
except Exception:
get_aiter_moe_config = None
aiter_moe = None
MoeQuantType = None
print("INFO: Please install aiter if you want to infer with aiter_moe.\n")
logger = init_logger(__name__) logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
...@@ -1742,6 +1761,32 @@ def fused_experts_impl( ...@@ -1742,6 +1761,32 @@ def fused_experts_impl(
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE) M = min(num_tokens, CHUNK_SIZE)
if envs.VLLM_ROCM_USE_AITER_MOE and use_int4_w4a16 and hidden_states.dtype == torch.bfloat16 and get_aiter_moe_config is not None and aiter_moe is not None:
# 根据 aiter 的config 判断是否启用 aiter
M, K = hidden_states.shape
E, N1, _ = w1.shape
_, N2, _ = w2.shape
top_k_num = topk_ids.size(1)
status, moe_config = get_aiter_moe_config(
M=M, E=E, N1=N1, N2=N2, K=K,
top_k=top_k_num, block_size=block_shape[1], dtype=hidden_states.dtype,
quant_type=MoeQuantType.W4A16,
)
if not status:
logger.info_once(
f"[aiter_moe_w4a16] SKIP {M=}, {E=}, {N1=}, {N2=}, {K=}, {top_k_num=}, {block_shape=}: "
f"no backend available"
)
else:
is_inplace = inplace and not disable_inplace()
return aiter_moe(hidden_states, w1, w2, topk_weights, topk_ids, moe_config, is_inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, global_num_experts, expert_map)
# return aiter_moe(hidden_states, w1, w2, topk_weights, topk_ids, moe_config, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
# block_shape, global_num_experts, expert_map, activation)
# Optional fast path: use Marlin W16A16 fused MoE implementation when the # Optional fast path: use Marlin W16A16 fused MoE implementation when the
# expert weights are already packed in Marlin layout. # expert weights are already packed in Marlin layout.
if not use_nn_moe: if not use_nn_moe:
...@@ -1858,35 +1903,74 @@ def fused_experts_impl( ...@@ -1858,35 +1903,74 @@ def fused_experts_impl(
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype) cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype)
if use_int8_w8a8 or use_fp8_w8a8: if use_int8_w8a8 or use_fp8_w8a8:
return fused_experts_impl_int8(hidden_states=hidden_states, if envs.VLLM_USE_AITER_MOE_W8A8==True:
w1=w1, K_input = hidden_states.size(1)
w2=w2, actual_N2 = N // 2
topk_weights=topk_weights, quant_type = MoeQuantType.W8A8
topk_ids=topk_ids, status, moe_config = get_aiter_moe_config(
cache13=cache13, M=num_tokens,
inplace=inplace, E=global_num_experts,
activation=activation, N1=N,
apply_router_weight_on_input=apply_router_weight_on_input, N2=actual_N2,
use_fp8_w8a8=use_fp8_w8a8, K=K_input,
use_int8_w8a8=use_int8_w8a8, top_k=top_k_num,
use_int8_w8a16=False, block_size=0,
use_int4_w4a16=False, dtype=hidden_states.dtype,
per_channel_quant=per_channel_quant, quant_type=quant_type,
global_num_experts=global_num_experts, )
expert_map=expert_map,
w1_scale=w1_scale, output = aiter_moe(
w2_scale=w2_scale, hidden_states=hidden_states,
w1_zp=w1_zp, w1=w1,
w2_zp=w2_zp, w2=w2,
a1_scale=a1_scale, topk_weights=topk_weights,
a2_scale=a2_scale, topk_ids=topk_ids,
block_shape=block_shape, moe_config=moe_config,
use_nn_moe=False, inplace=inplace,
routed_scaling_factor=routed_scaling_factor, activation=activation,
shared_output=shared_output, w1_scale=w1_scale,
i_q=i_q, w2_scale=w2_scale,
i_s=i_s w1_zp=w1_zp,
) w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=None,
global_num_experts=global_num_experts,
expert_map=expert_map,
routed_scaling_factor=routed_scaling_factor,
)
return output
else:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
i_q=i_q,
i_s=i_s
)
elif use_int4_w4a8 is True: elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states, return fused_experts_impl_w4a8(hidden_states=hidden_states,
w1=w1, w1=w1,
......
...@@ -733,7 +733,8 @@ class FusedMoE(CustomOp): ...@@ -733,7 +733,8 @@ class FusedMoE(CustomOp):
if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod", if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod",
"SlimQuantW4A8Int8MoEMethod", "SlimQuantW4A8Int8MoEMethod",
"SlimQuantW4A8Int8MarlinMoEMethod")): "SlimQuantW4A8Int8MarlinMoEMethod",
"SlimQuantW4A8Int8AiterMoEMethod")):
moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition
......
...@@ -45,7 +45,7 @@ QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) ...@@ -45,7 +45,7 @@ QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
DEPRECATED_QUANTIZATION_METHODS = [ DEPRECATED_QUANTIZATION_METHODS = [
"tpu_int8", "tpu_int8",
"ptpc_fp8", # "ptpc_fp8",
"fbgemm_fp8", "fbgemm_fp8",
"fp_quant", "fp_quant",
"bitblas", "bitblas",
......
...@@ -12,6 +12,7 @@ from compressed_tensors.quantization import ( ...@@ -12,6 +12,7 @@ from compressed_tensors.quantization import (
QuantizationStrategy, QuantizationStrategy,
) )
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -1806,6 +1807,36 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1806,6 +1807,36 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.a13_scale = None layer.a13_scale = None
layer.a2_scale = None layer.a2_scale = None
if envs.VLLM_ROCM_USE_AITER_MOE:
logger.info_once("Using aiter moe")
w1_zp = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // self.group_size // 2 if self.num_bits == 4 else hidden_size // self.group_size,
dtype=torch.uint8,
)
if self.num_bits == 4: w1_zp[:] = 136
w13_qzeros = torch.nn.Parameter(
w1_zp,
requires_grad=False,
)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
w2_zp = torch.empty(
num_experts,
intermediate_size_per_partition,
hidden_size // self.group_size // 2 if self.num_bits == 4 else hidden_size // self.group_size,
dtype=torch.uint8,
)
if self.num_bits == 4: w2_zp[:] = 136
w2_qzeros = torch.nn.Parameter(
w2_zp,
requires_grad=False,
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Reconfigure packed weights and scales to match moe_wna16 format # Reconfigure packed weights and scales to match moe_wna16 format
layer.w13_weight_packed = torch.nn.Parameter( layer.w13_weight_packed = torch.nn.Parameter(
...@@ -1836,8 +1867,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1836,8 +1867,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
return config_builder( return config_builder(
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
w1_zp=None, w1_zp=layer.w13_qzeros if envs.VLLM_ROCM_USE_AITER_MOE else None,
w2_zp=None, w2_zp=layer.w2_qzeros if envs.VLLM_ROCM_USE_AITER_MOE else None,
block_shape=[0, self.group_size], block_shape=[0, self.group_size],
) )
......
...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.fused_moe import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.utils.w8a8_utils import( from vllm.model_executor.layers.quantization.utils.w8a8_utils import(
get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight, weight8bit_nt_kpack2_marlin1) get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight, weight8bit_nt_kpack2_marlin1)
from vllm.model_executor.layers.fused_moe.config import (FusedMoEQuantConfig, int8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.config import (FusedMoEQuantConfig, int8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEMethodBase, FusedMoEMethodBase,
...@@ -26,6 +26,12 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -26,6 +26,12 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
) )
import aiter
from aiter.test_common import checkAllclose, perftest
from aiter.ops.shuffle import moe_layout_shuffle_gemm1, moe_layout_shuffle_gemm2
from aiter.fused_moe import fused_topk, torch_moe
from aiter import dtypes, ActivationType
from aiter.moe import get_aiter_moe_config, aiter_moe, MoeSolutionType, MoeQuantType
try: try:
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
from lmslim.layers.fused_moe.fuse_moe_fp8_marlin import fused_experts_impl_fp8_marlin from lmslim.layers.fused_moe.fuse_moe_fp8_marlin import fused_experts_impl_fp8_marlin
...@@ -114,14 +120,38 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod): ...@@ -114,14 +120,38 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
"dynamic per token quantization. Found static input scales.") "dynamic per token quantization. Found static input scales.")
self.fused_experts = self.fused_moe_forward self.fused_experts = self.fused_moe_forward
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
per_out_ch_quant=False,
block_shape=[256, 256] if self.use_deepep else None,
)
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep:
self.N = 2 * intermediate_size_per_partition
self.K = hidden_size
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
# WEIGHTS # WEIGHTS
...@@ -169,23 +199,51 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod): ...@@ -169,23 +199,51 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
def shuffle_w8a8_gemm1(self, weight_data):
w_fp8 = weight_data.to(torch.float8_e4m3fn)
shuffled = moe_layout_shuffle_gemm1(w_fp8)
return shuffled.view(torch.int8)
def shuffle_w8a8_gemm2(self, weight_data):
w_fp8 = weight_data.to(torch.float8_e4m3fn)
shuffled = moe_layout_shuffle_gemm2(w_fp8)
return shuffled.view(torch.int8)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = [] if envs.VLLM_USE_AITER_MOE_W8A8==True:
for ii in range(layer.w13_weight.shape[0]): layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data, requires_grad=False)
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii]) layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data, requires_grad=False)
w1_marlin_list.append(w1_marlin_in.float() if w1_marlin_in.dtype == torch.float8_e4m3fn else w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0) shuffled_w13 = self.shuffle_w8a8_gemm1(layer.w13_weight)
w1_marlin = fp32_to_fp8_e4m3fn(w1_marlin) w13_data = shuffled_w13.view(*layer.w13_weight.shape).view(torch.int8)
layer.w13_weight = Parameter(w13_data, requires_grad=False)
del w1_marlin_list
w2_marlin_list = [] shuffled_w2 = self.shuffle_w8a8_gemm2(layer.w2_weight)
for ii in range(layer.w2_weight.shape[0]): w2_data = shuffled_w2.view(*layer.w2_weight.shape).view(torch.int8)
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii]) layer.w2_weight = Parameter(w2_data, requires_grad=False)
w2_marlin_list.append(w2_marlin_in.float() if w2_marlin_in.dtype == torch.float8_e4m3fn else w2_marlin_in) else:
w2_marlin = torch.stack(w2_marlin_list, dim=0) w1_marlin_list = []
w2_marlin = fp32_to_fp8_e4m3fn(w2_marlin) for ii in range(layer.w13_weight.shape[0]):
layer.w13_weight = Parameter(w1_marlin, requires_grad=False) if not self.use_deepep:
layer.w2_weight = Parameter(w2_marlin, requires_grad=False) w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
else:
w1_marlin_in = weight8bit_nt_kpack2_marlin1(layer.w13_weight[ii])
w1_marlin_list.append(w1_marlin_in.float() if w1_marlin_in.dtype == torch.float8_e4m3fn else w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0)
w1_marlin = fp32_to_fp8_e4m3fn(w1_marlin)
del w1_marlin_list
w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]):
if not self.use_deepep:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else:
w2_marlin_in = weight8bit_nt_kpack2_marlin1(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in.float() if w2_marlin_in.dtype == torch.float8_e4m3fn else w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
w2_marlin = fp32_to_fp8_e4m3fn(w2_marlin)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def fused_moe_forward( def fused_moe_forward(
self, self,
...@@ -200,27 +258,66 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod): ...@@ -200,27 +258,66 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
): ):
if envs.VLLM_USE_AITER_MOE_W8A8==True:
return fused_experts_impl_fp8_marlin( m_flat = x.view(-1, x.shape[-1])
hidden_states=x, M = m_flat.shape[0]
w1=layer.w13_weight, E = layer.w13_weight.size(0)
w2=layer.w2_weight, K = x.size(-1)
topk_weights=topk_weights, N1 = layer.w13_weight.size(1)
topk_ids=topk_ids, topk = topk_ids.size(1)
inplace=True, w1_input = layer.w13_weight.view(E, N1, K)
activation=activation, w2_input = layer.w2_weight.view(E, K, N1 // 2)
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True, _, moe_cfg = get_aiter_moe_config(
per_channel_quant=True, M=M,
global_num_experts=global_num_experts, E=E,
expert_map=expert_map, N1=N1,
w1_scale=layer.w13_weight_scale, N2=N1 // 2,
w2_scale=layer.w2_weight_scale, K=K,
a1_scale=layer.w13_input_scale, top_k=topk,
a2_scale=layer.w2_input_scale, block_size=0,
use_nn_moe=False, dtype=x.dtype,
shared_output=shared_output, quant_type=MoeQuantType.W8A8,
routed_scaling_factor=routed_scaling_factor) )
output = aiter_moe(
hidden_states=x,
w1=w1_input,
w2=w2_input,
topk_weights=topk_weights,
topk_ids=topk_ids,
moe_config=moe_cfg,
inplace=False,
activation=getattr(layer, "activation", "silu"),
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=getattr(layer, "w13_input_scale", None),
a2_scale=getattr(layer, "w2_input_scale", None),
global_num_experts=E,
expert_map=getattr(layer, "expert_map", None),
routed_scaling_factor=routed_scaling_factor,
)
return output
else:
return fused_experts_impl_fp8_marlin(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
def apply( def apply(
self, self,
...@@ -261,6 +358,42 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod): ...@@ -261,6 +358,42 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output, ) shared_output=shared_output, )
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts,
)
if (
prepare_finalize.activation_format
== FusedMoEActivationFormat.BatchedExperts
):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
return BatchedDeepGemmExperts(
moe_config=self.moe,
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
N=self.N,
K=self.K
)
else:
logger.debug("DeepGemmExperts(%s)", self.__class__.__name__)
return DeepGemmExperts(moe_config=self.moe,
quant_config=self.moe_quant_config,
N=self.N,
K=self.K)
class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod): class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
def __init__( def __init__(
self, self,
...@@ -369,28 +502,44 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -369,28 +502,44 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
def shuffle_w8a8_gemm1(self, weight_data):
w_i8 = weight_data.to(torch.int8)
return moe_layout_shuffle_gemm1(w_i8)
def shuffle_w8a8_gemm2(self, weight_data):
w_i8 = weight_data.to(torch.int8)
return moe_layout_shuffle_gemm2(w_i8)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = [] if envs.VLLM_USE_AITER_MOE_W8A8==True:
for ii in range(layer.w13_weight.shape[0]): layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data, requires_grad=False)
if not self.use_deepep: layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data, requires_grad=False)
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii]) shuffled_w13 = self.shuffle_w8a8_gemm1(layer.w13_weight)
else: layer.w13_weight = Parameter(shuffled_w13.view(*layer.w13_weight.shape), requires_grad=False)
w1_marlin_in = weight8bit_nt_kpack2_marlin1(layer.w13_weight[ii]) shuffled_w2 = self.shuffle_w8a8_gemm2(layer.w2_weight)
w1_marlin_list.append(w1_marlin_in) layer.w2_weight = Parameter(shuffled_w2.view(*layer.w2_weight.shape), requires_grad=False)
w1_marlin = torch.stack(w1_marlin_list, dim=0) else:
w1_marlin_list = []
del w1_marlin_list for ii in range(layer.w13_weight.shape[0]):
w2_marlin_list = [] if not self.use_deepep:
for ii in range(layer.w2_weight.shape[0]): w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
if not self.use_deepep: else:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii]) w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
else: w1_marlin_list.append(w1_marlin_in)
w2_marlin_in = weight8bit_nt_kpack2_marlin1(layer.w2_weight[ii]) w1_marlin = torch.stack(w1_marlin_list, dim=0)
w2_marlin_list.append(w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0) del w1_marlin_list
w2_marlin_list = []
layer.w13_weight = Parameter(w1_marlin, requires_grad=False) for ii in range(layer.w2_weight.shape[0]):
layer.w2_weight = Parameter(w2_marlin, requires_grad=False) if not self.use_deepep:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else:
w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def apply( def apply(
self, self,
...@@ -406,30 +555,70 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -406,30 +555,70 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts_impl_int8_marlin( if envs.VLLM_USE_AITER_MOE_W8A8==True:
hidden_states=x, m_flat = x.view(-1, x.shape[-1])
w1=layer.w13_weight, M = m_flat.shape[0]
w2=layer.w2_weight, E = layer.w13_weight.size(0)
topk_weights=topk_weights, K = x.size(-1)
topk_ids=topk_ids, N1 = layer.w13_weight.size(1)
inplace=True, topk = topk_ids.size(1)
activation=layer.activation, w1_input = layer.w13_weight.view(E, N1, K)
apply_router_weight_on_input=layer.apply_router_weight_on_input, w2_input = layer.w2_weight.view(E, K, N1 // 2)
use_int8_w8a8=True,
per_channel_quant=True, _, moe_cfg = get_aiter_moe_config(
global_num_experts=layer.global_num_experts, M=M,
expert_map=layer.expert_map, E=E,
quant_config=self.moe_quant_config, N1=N1,
w1_scale=layer.w13_weight_scale, N2=N1 // 2,
w2_scale=layer.w2_weight_scale, K=K,
a1_scale=layer.w13_input_scale, top_k=topk,
a2_scale=layer.w2_input_scale, block_size=0,
use_nn_moe=False, dtype=x.dtype,
i_q=i_q, quant_type=MoeQuantType.W8A8,
i_s=i_s, )
shared_output=shared_output, output = aiter_moe(
routed_scaling_factor=routed_scaling_factor, hidden_states=x,
) w1=w1_input,
w2=w2_input,
topk_weights=topk_weights,
topk_ids=topk_ids,
moe_config=moe_cfg,
inplace=False,
activation=getattr(layer, "activation", "silu"),
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=getattr(layer, "w13_input_scale", None),
a2_scale=getattr(layer, "w2_input_scale", None),
global_num_experts=E,
expert_map=getattr(layer, "expert_map", None),
routed_scaling_factor=routed_scaling_factor,
)
return output
else:
return fused_experts_impl_int8_marlin(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def select_gemm_impl( def select_gemm_impl(
self, self,
......
...@@ -52,7 +52,7 @@ class QuantFP8(CustomOp): ...@@ -52,7 +52,7 @@ class QuantFP8(CustomOp):
column major format column major format
:param compile_native: Manually compile forward_native if compile mode > None :param compile_native: Manually compile forward_native if compile mode > None
""" """
super().__init__(compile_native=compile_native) super().__init__(compile_native=compile_native, enforce_enable=True)
self.static = static self.static = static
self.group_shape = group_shape self.group_shape = group_shape
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
......
...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.fp8 import ( ...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config, Fp8Config,
Fp8KVCacheMethod, Fp8KVCacheMethod,
Fp8LinearMethod, Fp8LinearMethod,
Fp8OnlineLinearMethod,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
...@@ -42,10 +43,10 @@ class PTPCFp8Config(Fp8Config): ...@@ -42,10 +43,10 @@ class PTPCFp8Config(Fp8Config):
if not current_platform.is_rocm(): if not current_platform.is_rocm():
raise ValueError("ptpc_fp8 quantization is supported only on ROCm.") raise ValueError("ptpc_fp8 quantization is supported only on ROCm.")
if not current_platform.has_device_capability(94): # if not current_platform.has_device_capability(94):
raise ValueError( # raise ValueError(
"ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." # noqa: E501 # "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." # noqa: E501
) # )
if activation_scheme == "static": if activation_scheme == "static":
raise ValueError("ptpc_fp8 as of now only support dynamic quantization.") raise ValueError("ptpc_fp8 as of now only support dynamic quantization.")
...@@ -77,7 +78,7 @@ class PTPCFp8Config(Fp8Config): ...@@ -77,7 +78,7 @@ class PTPCFp8Config(Fp8Config):
return None return None
class PTPCFp8LinearMethod(Fp8LinearMethod): class PTPCFp8LinearMethod(Fp8OnlineLinearMethod):
"""Linear method for Per-Token and Per-Channel FP8 Quantization. """Linear method for Per-Token and Per-Channel FP8 Quantization.
Only supports loading quantized BF16 model checkpoints with dynamic Only supports loading quantized BF16 model checkpoints with dynamic
activation scaling. To load FP16 model checkpoints, user must specify activation scaling. To load FP16 model checkpoints, user must specify
...@@ -114,13 +115,13 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): ...@@ -114,13 +115,13 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
if layer.weight.data.dtype == torch.bfloat16: if layer.weight.data.dtype == torch.bfloat16:
# Quantize the weights. # Quantize the weights.
qweight, weight_scale = ops.scaled_fp8_quant( qweight, weight_scale = ops.scaled_fp8_quant_weight(
layer.weight, scale=None, use_per_token_if_dynamic=True layer.weight, scale=None, use_per_token_if_dynamic=True
) )
# Update the layer with the new values. # Update the layer with the new values.
layer.weight = Parameter( layer.weight = Parameter(
qweight.t(), requires_grad=False qweight.contiguous(), requires_grad=False
) # Pretranspose the weight ) # Pretranspose the weight
layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else: else:
......
...@@ -25,6 +25,21 @@ import os ...@@ -25,6 +25,21 @@ import os
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
from aiter.ops.shuffle import w4a8_moe_layout_shuffle_gemm1,w4a8_moe_layout_shuffle_gemm2
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
from aiter import dtypes, ActivationType
except ImportError as e:
print("Import error msg: import aiter")
W8A8_TRITONJSON=W8a8GetCacheJSON() W8A8_TRITONJSON=W8a8GetCacheJSON()
def baseline_scaled_mm(a: torch.Tensor, def baseline_scaled_mm(a: torch.Tensor,
...@@ -82,7 +97,10 @@ class SlimQuantW4A8Int8Config(QuantizationConfig): ...@@ -82,7 +97,10 @@ class SlimQuantW4A8Int8Config(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return SlimQuantW4A8Int8LinearMethod(self) return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return SlimQuantW4A8Int8MoEMethod(self, layer.moe_config) if envs.VLLM_ROCM_USE_AITER_MOE:
return SlimQuantW4A8Int8AiterMoEMethod(self, layer.moe_config)
else:
return SlimQuantW4A8Int8MoEMethod(self, layer.moe_config)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
...@@ -328,4 +346,209 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -328,4 +346,209 @@ class SlimQuantW4A8Int8MoEMethod:
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
)
class SlimQuantW4A8Int8AiterMoEMethod:
"""MoE method for W4A8INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def __new__(cls, *args, **kwargs):
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config, moe):
self.moe = moe
self.quant_config = quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.moe_mk: Optional[FusedMoEModularKernel] = None
def get_fused_moe_quant_config(
self, layer: torch.nn.Module)-> Optional[FusedMoEQuantConfig]:
self.moe_quant_config = FusedMoEQuantConfig.make(
torch.int8,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
per_out_ch_quant=False,
block_shape=None,
weight_dtype='int4'
)
return self.moe_quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
tp_size = get_tensor_model_parallel_world_size()
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
) )
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w13_input_scale = None
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = None
layer.register_parameter("w2_input_scale", w2_input_scale)
def repack_and_shuffle_w4a8(self, weight_data, E):
"""
逐 expert 处理 [n, k_half]
处理完直接写回 weight_data[i]
"""
# 原始 shape: [E, n, k_half]
for i in range(E):
# 1. 取当前 expert [n, k_half]
expert = weight_data[i]
n, k_half = expert.shape
# 2. repack 逻辑(连续 → blocked)
w_u8 = expert.to(torch.uint8)
# 解包 1byte → 2个4bit
w_unpacked = torch.stack([
(w_u8 >> 4) & 0x0F,
w_u8 & 0x0F
], dim=-1).view(n, -1)
# 8个4bit分块重排
blocks = w_unpacked.view(n, -1, 8)
w_low = blocks[..., :4]
w_high = blocks[..., 4:]
packed = (w_low << 4) | w_high
packed = packed.view(n, k_half)
# 3. shuffle
w_marlin_in = w4a8_moe_layout_shuffle_gemm2(packed)
w_marlin_in = w_marlin_in.reshape(n, k_half)
# 4. 直接写回
weight_data[i] = w_marlin_in
return weight_data
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E=layer.w13_weight.shape[0]
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data, requires_grad=False
)
layer.w13_weight = Parameter(self.repack_and_shuffle_w4a8(layer.w13_weight.data, E), requires_grad=False)
layer.w2_weight = Parameter(self.repack_and_shuffle_w4a8(layer.w2_weight.data, E), requires_grad=False)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
E = layer.w13_weight.size(0)
K = x.size(-1)
N1 = layer.w13_weight.size(1)
if x.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
M = x.size(0)
else:
assert x.dim() == 3
assert x.size(0) == E, f"{x.size(0)} == {E}"
M = x.size(1)
topk = topk_ids.size(1)
status, moe_cfg = get_aiter_moe_config(
M=M,
E=E,
N1=N1,
N2=N1//2,
K=K,
top_k=topk,
block_size=None,
dtype=dtypes.bf16,
quant_type=MoeQuantType.W4A8,
)
if not status:
assert moe_cfg.solution_type is None
assert moe_cfg.config is None
logger.info(f"[get_config_w4a8] {M=}, no solution found")
return aiter_moe(
x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
moe_config=moe_cfg,
activation="silu",
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
global_num_experts=E,
expert_map=None,
)
\ No newline at end of file
...@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, ...@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod, SlimQuantW4A8Int8AiterMoEMethod
from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache
try: try:
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
...@@ -111,7 +111,10 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig): ...@@ -111,7 +111,10 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return SlimQuantW4A8Int8LinearMethod(self) return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return SlimQuantW4A8Int8MarlinMoEMethod(self, layer.moe_config) if envs.VLLM_ROCM_USE_AITER_MOE:
return SlimQuantW4A8Int8AiterMoEMethod(self, layer.moe_config)
else:
return SlimQuantW4A8Int8MarlinMoEMethod(self, layer.moe_config)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
......
...@@ -1515,11 +1515,6 @@ class GPUModelRunner( ...@@ -1515,11 +1515,6 @@ class GPUModelRunner(
tp_size: int, tp_size: int,
req_ids: list[str], req_ids: list[str],
): ):
tokens_per_rank = (total_q_len + tp_size - 1) // tp_size
start_token = tp_rank * tokens_per_rank
end_token = min((tp_rank + 1) * tokens_per_rank, total_q_len)
q_lens = [] q_lens = []
seq_count = 0 seq_count = 0
seq_indexes = [] seq_indexes = []
...@@ -1607,6 +1602,10 @@ class GPUModelRunner( ...@@ -1607,6 +1602,10 @@ class GPUModelRunner(
if isinstance(rank_tokens, torch.Tensor): if isinstance(rank_tokens, torch.Tensor):
rank_tokens = rank_tokens.item() rank_tokens = rank_tokens.item()
else: else:
tokens_per_rank = (total_q_len + tp_size - 1) // tp_size
start_token = tp_rank * tokens_per_rank
end_token = min((tp_rank + 1) * tokens_per_rank, total_q_len)
current_seq = 0 current_seq = 0
current_pos = 0 current_pos = 0
rank_tokens = min(tokens_per_rank, end_token - start_token) rank_tokens = min(tokens_per_rank, end_token - start_token)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try:
from ._version import __version__, __version_tuple__
except Exception as e:
import warnings
warnings.warn(f"Failed to read commit hash:\n{e}", RuntimeWarning, stacklevel=2)
__version__ = "dev"
__version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str):
"""Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
For example - return True if the current version if 0.7.4 and the
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
"""
# Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0):
return True
# Note - this won't do the right thing when we release 1.0!
assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version():
"""For the purpose of testing, return a previous minor version number."""
# In dev tree, this will return "0.-1", but that will work fine"
assert isinstance(__version_tuple__[1], int)
return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment