"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "5a468ff7c7463a9ed4e6353481a6c0dfb2bfa100"
Unverified Commit 5f7fab88 authored by vllmellm's avatar vllmellm Committed by GitHub
Browse files

[ROCm][FEAT] Integrate aiter gemm w8a8 ptpc (#33773)


Signed-off-by: default avatarvllmellm <vllm.ellm@embeddedllm.com>
parent 343f6523
......@@ -1856,6 +1856,7 @@ class TestFP8Layer(torch.nn.Module):
out_dtype=out_dtype,
force_kernel=force_kernel,
)
self.kernel.process_weights_after_loading(self)
def is_quant_fp8_enabled(self) -> bool:
return self.kernel.quant_fp8.enabled()
......
......@@ -3,6 +3,7 @@
import functools
from collections.abc import Callable
import pandas as pd
import torch
from torch._ops import OpOverload
......@@ -56,6 +57,29 @@ def is_aiter_found_and_supported() -> bool:
return False
@functools.cache
def _load_gemm_tuned_configs(
q_dtype_w: torch.dtype, csv_path: str
) -> set[tuple[int, int, int]]:
try:
df = pd.read_csv(csv_path).drop_duplicates()
df = df[df["q_dtype_w"] == str(q_dtype_w)]
return set(zip(df["N"].astype(int), df["K"].astype(int), df["M"].astype(int)))
except Exception:
return set()
def _check_kernel_tuned(N: int, K: int, q_dtype_w: torch.dtype, csv_path: str) -> bool:
configs = _load_gemm_tuned_configs(q_dtype_w, csv_path)
l_m = (
[1, 2, 4]
+ list(range(8, 513, 8))
+ [1024, 1536]
+ [2**i for i in range(11, 19)]
)
return any((N, K, M) in configs for M in l_m)
def if_aiter_supported(func: Callable) -> Callable:
"""Decorator that only executes the function if
ROCm AITER package is supported and enabled on gfx9 archs.
......@@ -468,7 +492,7 @@ def _rocm_aiter_mla_decode_fwd_fake(
pass
def _rocm_aiter_gemm_a8w8_impl(
def _rocm_aiter_w8a8_gemm_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
......@@ -485,7 +509,7 @@ def _rocm_aiter_gemm_a8w8_impl(
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
def _rocm_aiter_gemm_a8w8_fake(
def _rocm_aiter_w8a8_gemm_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
......@@ -499,6 +523,35 @@ def _rocm_aiter_gemm_a8w8_fake(
return Y
def _rocm_aiter_preshuffled_per_token_w8a8_gemm_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter import gemm_a8w8_bpreshuffle
output = gemm_a8w8_bpreshuffle(A, B, As, Bs, None, output_dtype)
if bias is not None:
output.add_(bias)
return output
def _rocm_aiter_preshuffled_per_token_w8a8_gemm_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
return torch.empty(m, n, dtype=output_dtype, device=A.device)
def _rocm_aiter_triton_gemm_a8w8_blockscale_impl(
A: torch.Tensor,
B: torch.Tensor,
......@@ -1313,11 +1366,15 @@ class rocm_aiter_ops:
)
direct_register_custom_op(
op_name="rocm_aiter_gemm_a8w8",
op_func=_rocm_aiter_gemm_a8w8_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_a8w8_fake,
dispatch_key=current_platform.dispatch_key,
op_name="rocm_aiter_w8a8_gemm",
op_func=_rocm_aiter_w8a8_gemm_impl,
fake_impl=_rocm_aiter_w8a8_gemm_fake,
)
direct_register_custom_op(
op_name="_rocm_aiter_preshuffled_per_token_w8a8_gemm",
op_func=_rocm_aiter_preshuffled_per_token_w8a8_gemm_impl,
fake_impl=_rocm_aiter_preshuffled_per_token_w8a8_gemm_fake,
)
direct_register_custom_op(
......@@ -1493,7 +1550,7 @@ class rocm_aiter_ops:
)
@staticmethod
def gemm_a8w8(
def w8a8_gemm(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
......@@ -1501,7 +1558,20 @@ class rocm_aiter_ops:
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype)
return torch.ops.vllm.rocm_aiter_w8a8_gemm(A, B, As, Bs, bias, output_dtype)
@staticmethod
def preshuffled_per_token_w8a8_gemm(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
return torch.ops.vllm._rocm_aiter_preshuffled_per_token_w8a8_gemm(
A, B, As, Bs, bias, output_dtype
)
@staticmethod
def triton_gemm_a8w8_blockscale(
......@@ -1920,6 +1990,24 @@ class rocm_aiter_ops:
(8192, 3584),
]
@staticmethod
def is_shuffled_per_token_w8a8_gemm_tuned(
N: int, K: int, q_dtype_w: torch.dtype
) -> bool:
import aiter.ops.gemm_op_a8w8 as aiter_gemm_a8w8_ops
csv_path = (
aiter_gemm_a8w8_ops.AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE
)
return _check_kernel_tuned(N, K, q_dtype_w, csv_path)
@staticmethod
def is_per_token_w8a8_gemm_tuned(N: int, K: int, q_dtype_w: torch.dtype) -> bool:
import aiter.ops.gemm_op_a8w8 as aiter_gemm_a8w8_ops
csv_path = aiter_gemm_a8w8_ops.AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_FILE
return _check_kernel_tuned(N, K, q_dtype_w, csv_path)
@staticmethod
def shuffle_weight(
tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
......
......@@ -106,6 +106,8 @@ from vllm.model_executor.kernels.linear.scaled_mm import (
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterFp8BlockScaledMMKernel,
AiterInt8ScaledMMLinearKernel,
AiterPerTokenFp8ScaledMMLinearKernel,
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
CPUInt8ScaledMMLinearKernel,
......@@ -165,6 +167,8 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
ChannelWiseTorchFP8ScaledMMLinearKernel,
],
PlatformEnum.ROCM: [
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel,
AiterPerTokenFp8ScaledMMLinearKernel,
ROCmFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
......@@ -360,18 +364,18 @@ def choose_scaled_mm_linear_kernel(
def init_fp8_linear_kernel(
activation_quant_key: QuantKey,
weight_quant_key: QuantKey,
weight_shape: tuple[int, int],
input_dtype: torch.dtype,
out_dtype: torch.dtype,
force_kernel: type[_KernelT] | None = None,
weight_shape: tuple[int, int],
force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
module_name: str | None = None,
) -> FP8ScaledMMLinearKernel | Fp8BlockScaledMMLinearKernel:
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
weight_quant_key=weight_quant_key,
activation_quant_key=activation_quant_key,
weight_shape=weight_shape,
input_dtype=input_dtype,
out_dtype=out_dtype,
weight_shape=weight_shape,
)
if activation_quant_key.scale.group_shape.is_per_group():
......@@ -725,6 +729,8 @@ __all__ = [
"FP8ScaledMMLinearLayerConfig",
"Int8ScaledMMLinearLayerConfig",
"ScaledMMLinearLayerConfig",
"AiterPreshuffledPerTokenFp8ScaledMMLinearKernel",
"AiterPerTokenFp8ScaledMMLinearKernel",
"NvFp4LinearKernel",
"NvFp4LinearLayerConfig",
"AiterInt8ScaledMMLinearKernel",
......
......@@ -5,18 +5,27 @@
import torch
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm._aiter_ops import (
rocm_aiter_ops,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform
from .BlockScaledMMLinearKernel import (
Fp8BlockScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearLayerConfig,
)
logger = init_logger(__name__)
class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
......@@ -113,7 +122,154 @@ class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
# a to be [M, K]
# b to be [N, K]
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
return rocm_aiter_ops.w8a8_gemm(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
class AiterPreshuffledPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_rocm():
return False, "requires ROCm."
if not rocm_aiter_ops.is_linear_fp8_enabled():
return (
False,
"requires setting `VLLM_ROCM_USE_AITER=1` "
"and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
"`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
)
try:
import aiter # noqa: F401
except Exception:
return False, "requires aiter library to be installed."
return True, None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
is_ptpc = (
c.activation_quant_key.scale.group_shape.is_per_token()
and c.weight_quant_key.scale.group_shape.is_per_channel()
)
if c.weight_shape is None:
return False, "weight_shape is required for Aiter kernels"
N, K = c.weight_shape
fp8_dtype = current_platform.fp8_dtype()
if c.out_dtype is not torch.bfloat16:
return False, "requires bfloat16 output dtype."
if not is_ptpc:
return (
False,
"requires per token activation scales and per channel weight scales.",
)
if not (N % 16 == 0 and K % 16 == 0):
return (
False,
f"requires N and K dimensions divisible by 16, received "
f"N={N} and K={K}.",
)
# Aiter's shuffled per-token Gemm performs better than torch only when its
# tuned.
if not rocm_aiter_ops.is_shuffled_per_token_w8a8_gemm_tuned(N, K, fp8_dtype):
return (
False,
f"requires a tuned configuration for N: {N} and K: {K} "
f"and fp8 dtype {fp8_dtype}.",
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w_name, *_ = self.layer_param_names
w, *_ = self._get_layer_params(layer)
replace_parameter(
layer,
w_name,
torch.nn.Parameter(
rocm_aiter_ops.shuffle_weight(w.t().contiguous()).data,
requires_grad=False,
),
)
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
return rocm_aiter_ops.preshuffled_per_token_w8a8_gemm(
A, B, As, Bs, bias, out_dtype
)
class AiterPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
return AiterPreshuffledPerTokenFp8ScaledMMLinearKernel.is_supported(
compute_capability
)
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
is_ptpc = (
c.activation_quant_key.scale.group_shape.is_per_token()
and c.weight_quant_key.scale.group_shape.is_per_channel()
)
if c.weight_shape is None:
return False, "weight_shape is required for Aiter kernels"
N, K = c.weight_shape
fp8_dtype = current_platform.fp8_dtype()
if not is_ptpc:
return (
False,
"requires per token activation scales and per channel weight scales.",
)
# Aiter's per-token Gemm performs better than torch oonly when its
# tuned.
if not rocm_aiter_ops.is_per_token_w8a8_gemm_tuned(N, K, fp8_dtype):
return (
False,
f"requires a tuned configuration for N: {N} and K: {K} "
f"and fp8 dtype {fp8_dtype}.",
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w_name, *_ = self.layer_param_names
w, *_ = self._get_layer_params(layer)
replace_parameter(
layer,
w_name,
torch.nn.Parameter(w.t(), requires_grad=False),
)
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
return rocm_aiter_ops.w8a8_gemm(A, B, As, Bs, bias, out_dtype)
class AiterFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
......
......@@ -31,8 +31,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
create_fp8_quant_key,
kFp8DynamicTokenSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
......@@ -47,7 +47,7 @@ activation_quant_key_mapping = {
DYNAMIC_QUANT: kFp8DynamicTokenSym,
}
weight_quant_key_mapping = {
QuantizationStrategy.CHANNEL: kFp8StaticTokenSym,
QuantizationStrategy.CHANNEL: kFp8StaticChannelSym,
QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
}
logger = init_logger(__name__)
......@@ -67,7 +67,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
assert not self.is_static_input_scheme
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(*self.weight_block_size)
)
......@@ -76,7 +75,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
)
else:
self.activation_quant_key = activation_quant_key_mapping[
is_static_input_scheme
self.is_static_input_scheme
]
self.weight_quant_key = weight_quant_key_mapping[self.strategy]
......@@ -138,9 +137,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
weight_shape=(output_size_per_partition, input_size_per_partition),
module_name=self.__class__.__name__,
)
......
......@@ -175,6 +175,8 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
# Activations not quantized for marlin.
del layer.input_scale_ub
self.fp8_linear.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
......
......@@ -397,8 +397,6 @@ class Fp8LinearMethod(LinearMethodBase):
if self.block_quant:
assert not self.act_q_static
self.fp8_linear.process_weights_after_loading(layer)
# If checkpoint not serialized fp8, quantize the weights.
else:
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
......@@ -428,6 +426,8 @@ class Fp8LinearMethod(LinearMethodBase):
else:
layer.input_scale = None
self.fp8_linear.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
......
......@@ -517,6 +517,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
self.fp8_linear.process_weights_after_loading(layer)
def apply(
self,
......@@ -597,6 +598,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
self.fp8_linear.process_weights_after_loading(layer)
def apply(
self,
......
......@@ -120,6 +120,8 @@ class QuarkW8A8Fp8(QuarkScheme):
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
self.fp8_linear.process_weights_after_loading(layer)
def create_weights(
self,
layer: torch.nn.Module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment