Unverified Commit f6227c22 authored by czhu-cohere's avatar czhu-cohere Committed by GitHub
Browse files

[Kernel]Support W4A8 Grouped GEMM on Hopper (#29691)


Signed-off-by: default avatarczhu-cohere <conway.zhu@cohere.com>
parent ea657f20
...@@ -6,7 +6,11 @@ import torch ...@@ -6,7 +6,11 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
...@@ -48,7 +52,6 @@ class CutlassW4A8LinearKernel(MPLinearKernel): ...@@ -48,7 +52,6 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
"CUTLASS W4A8, only supported int4", "CUTLASS W4A8, only supported int4",
) )
# TODO(czhu): support -1 (column-wise)
if c.group_size != 128: if c.group_size != 128:
return False, "Only group_size 128 is supported" return False, "Only group_size 128 is supported"
...@@ -71,9 +74,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel): ...@@ -71,9 +74,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1} # `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module): def process_weights_after_loading(self, layer: torch.nn.Module):
# TODO(czhu): optimize speed/mem usage
def transform_w_q(x): def transform_w_q(x):
assert isinstance(x, BasevLLMParameter) assert isinstance(x, BasevLLMParameter)
convert_packed_uint4b8_to_signed_int4_inplace(x.data)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t()) x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
return x return x
...@@ -85,10 +88,18 @@ class CutlassW4A8LinearKernel(MPLinearKernel): ...@@ -85,10 +88,18 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
x.data = ops.cutlass_pack_scale_fp8(x.data) x.data = ops.cutlass_pack_scale_fp8(x.data)
return x return x
w_s = getattr(layer, self.w_s_name)
fp8_scales, chan_scales = convert_bf16_scales_to_fp8(self.quant_fp8, w_s.data)
w_s.data = fp8_scales
# register per-channel scales
layer.register_parameter(
"weight_chan_scale", torch.nn.Parameter(chan_scales, requires_grad=False)
)
# Encode/reorder weights and pack scales # Encode/reorder weights and pack scales
self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s) self._transform_param(layer, self.w_s_name, transform_w_s)
self._transform_param(layer, "weight_chan_scale", lambda x: x)
def apply_weights( def apply_weights(
self, self,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This file is used for /tests and /benchmarks""" """This file is used for /tests and /benchmarks"""
from collections.abc import Mapping from collections.abc import Callable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
from types import MappingProxyType from types import MappingProxyType
from typing import ClassVar, NamedTuple from typing import ClassVar, NamedTuple
...@@ -691,3 +691,51 @@ def cutlass_fp4_supported() -> bool: ...@@ -691,3 +691,51 @@ def cutlass_fp4_supported() -> bool:
capability_tuple = current_platform.get_device_capability() capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int() capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability) return cutlass_scaled_mm_supports_fp4(capability)
def convert_bf16_scales_to_fp8(
quant_fp8: Callable, scales: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Convert a BF16 scale tensor into the pair of (fp8_scales, channel_scales)
expected by W4A8 GEMM kernels.
"""
assert scales.is_contiguous(), (
f"scale tensor must be contiguous, got {scales.stride()=}"
)
assert scales.is_cuda, "scales must be on gpu"
orig_shape = scales.shape
k_groups = orig_shape[-1]
flat_scales = scales.view(-1, k_groups)
fp8_scales, chan_scales = quant_fp8(flat_scales)
fp8_scales = (fp8_scales.float() / 8.0).to(torch.float8_e4m3fn)
chan_scales *= 8.0
# restore original shape
fp8_scales = fp8_scales.view(orig_shape)
chan_scales = chan_scales.view(orig_shape[:-1], -1)
return fp8_scales, chan_scales
def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tensor:
"""
Convert int4b8 (packed to int32) to signed int4
"""
assert t.is_cuda, "tensor must be on gpu"
assert t.dtype == torch.int32, f"expected int32 packed weights but got {t.dtype}"
# loop through the 8 4-bit nibbles in each int32 entry
for i in range(8):
shift = 4 * i
# extract the i-th 4-bit nibble
nib = (t >> shift) & 0xF
# clear the original nibble by masking out
t &= ~(0xF << shift)
# convert int4b8 [0..15] to signed int4 [-8..7] by subtracting 8
# and update in-place
t |= ((nib - 8) & 0xF) << shift
return t
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment