Unverified commit f6227c22, authored by czhu-cohere and committed by GitHub
Browse files

[Kernel]Support W4A8 Grouped GEMM on Hopper (#29691)


Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
parent ea657f20
......@@ -6,7 +6,11 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
......@@ -48,7 +52,6 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
"CUTLASS W4A8, only supported int4",
)
# TODO(czhu): support -1 (column-wise)
if c.group_size != 128:
return False, "Only group_size 128 is supported"
......@@ -71,9 +74,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
# TODO(czhu): optimize speed/mem usage
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
convert_packed_uint4b8_to_signed_int4_inplace(x.data)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
return x
......@@ -85,10 +88,18 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
x.data = ops.cutlass_pack_scale_fp8(x.data)
return x
w_s = getattr(layer, self.w_s_name)
fp8_scales, chan_scales = convert_bf16_scales_to_fp8(self.quant_fp8, w_s.data)
w_s.data = fp8_scales
# register per-channel scales
layer.register_parameter(
"weight_chan_scale", torch.nn.Parameter(chan_scales, requires_grad=False)
)
# Encode/reorder weights and pack scales
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
self._transform_param(layer, "weight_chan_scale", lambda x: x)
def apply_weights(
self,
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This file is used for /tests and /benchmarks"""
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from types import MappingProxyType
from typing import ClassVar, NamedTuple
......@@ -691,3 +691,51 @@ def cutlass_fp4_supported() -> bool:
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability)
def convert_bf16_scales_to_fp8(
quant_fp8: Callable, scales: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Convert a BF16 scale tensor into the pair of (fp8_scales, channel_scales)
expected by W4A8 GEMM kernels.
"""
assert scales.is_contiguous(), (
f"scale tensor must be contiguous, got {scales.stride()=}"
)
assert scales.is_cuda, "scales must be on gpu"
orig_shape = scales.shape
k_groups = orig_shape[-1]
flat_scales = scales.view(-1, k_groups)
fp8_scales, chan_scales = quant_fp8(flat_scales)
fp8_scales = (fp8_scales.float() / 8.0).to(torch.float8_e4m3fn)
chan_scales *= 8.0
# restore original shape
fp8_scales = fp8_scales.view(orig_shape)
chan_scales = chan_scales.view(orig_shape[:-1], -1)
return fp8_scales, chan_scales
def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tensor:
    """
    Convert int4b8 (packed to int32) to signed int4, in place.

    Each int32 element holds eight 4-bit nibbles. A nibble stored as int4b8
    is the signed value plus a bias of 8, i.e. an unsigned value in [0..15].
    Subtracting 8 modulo 16 yields the two's-complement int4 encoding of the
    same value, and subtracting 8 modulo 16 is the same as flipping bit 3 of
    the nibble (``(nib - 8) & 0xF == nib ^ 0x8``).

    Args:
        t: GPU tensor of dtype ``torch.int32`` containing packed nibbles.
            It is modified in place.

    Returns:
        The same tensor object ``t``, after conversion.

    Raises:
        AssertionError: if ``t`` is not on a GPU or is not ``torch.int32``.
    """
    assert t.is_cuda, "tensor must be on gpu"
    assert t.dtype == torch.int32, f"expected int32 packed weights but got {t.dtype}"

    # Flip the top bit of all eight nibbles in one pass. The bit pattern is
    # 0x88888888, written as the equivalent signed int32 value so the scalar
    # fits the tensor dtype without relying on overflow wrap-around.
    t ^= -0x77777778
    return t
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment