Unverified commit a0b26701, authored by Kyle Sayers and committed by GitHub
Browse files

[Transform] Deterministic Hadacore Transforms (#24106)


Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
parent c4afdb69
......@@ -783,6 +783,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()
# Hadacore kernels
# Compute HADACORE_ARCHS from the candidate list "8.0;8.9;9.0" and the
# requested CUDA_ARCHS (presumably the subset of requested architectures that
# the hadacore kernel can be built for -- see cuda_archs_loose_intersection).
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0;8.9;9.0" "${CUDA_ARCHS}")
# Only add the kernel source when HADACORE_ARCHS is non-empty.
if(HADACORE_ARCHS)
  set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu")
  # Set gencode flags for this source using the hadacore architecture list.
  set_gencode_flags_for_srcs(
    SRCS "${SRCS}"
    CUDA_ARCHS "${HADACORE_ARCHS}")
  # Append the source to the extension source list.
  list(APPEND VLLM_EXT_SRC "${SRCS}")
  message(STATUS "Building hadacore")
endif()
# if CUDA endif
endif()
......
......@@ -347,6 +347,8 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace);
#ifdef USE_ROCM
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
std::optional<int64_t> qr_max_size = std::nullopt);
......
This diff is collapsed.
......@@ -613,6 +613,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
// Hadamard transforms
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
#ifndef USE_ROCM
// Compute per-token-group FP8 quantized tensor and scaling factor.
ops.def(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import pytest
import torch
from compressed_tensors.transform import deterministic_hadamard_matrix
from vllm import _custom_ops as ops
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("hidden_dim", [2**n for n in range(10)])
def test_hadacore(batch_size, hidden_dim, dtype=torch.bfloat16, device="cuda"):
    """Check the hadacore kernel against a dense Hadamard reference.

    The identity matrix is transformed once and compared with the product of
    the identity and the normalized deterministic Hadamard matrix; it is then
    transformed a second time and compared with the original input, since the
    transform is expected to be its own inverse.

    :param batch_size: parametrized value; NOTE(review): currently not used
        by the test body -- confirm whether it should shape the input
    :param hidden_dim: size of the (square) identity input, a power of two
    :param dtype: dtype of the tensor passed to the kernel
    :param device: device on which both the input and the reference live
    """
    x = torch.eye(hidden_dim, dtype=dtype, device=device)

    # Build the reference on the same device as the input (previously this was
    # hard-coded to "cuda", which ignored the `device` argument).
    hadamard = deterministic_hadamard_matrix(
        hidden_dim, dtype=torch.float64, device=device) / math.sqrt(hidden_dim)

    # Clone so that `x` stays untouched even if the kernel works in place.
    y = ops.hadacore_transform(x.clone())
    y_true = (x.to(hadamard.dtype) @ hadamard.T).to(y.dtype)
    assert torch.allclose(y, y_true)

    # Applying the transform twice should give back the original input.
    y = ops.hadacore_transform(y)
    assert torch.allclose(y, x)
......@@ -2011,3 +2011,27 @@ def onednn_scaled_mm(
input_zp_adj, bias, dnnl_handler.handler)
return output
def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor:
    """
    Perform Hadamard transforms using [Hadacore](https://arxiv.org/abs/2412.08832)
    kernels. Note that these kernels exploit the recursive properties of
    Sylvester Hadamards, and therefore do not require transform weight data.

    Note that Sylvester Hadamard transforms are also symmetric, which means
    that this function also applies the (transpose <=> inverse) transform.

    :param x: value to be transformed
    :param inplace: modify value in place
    :return: value after transformation
    """
    return torch.ops._C.hadacore_transform(x, inplace)
if hasattr(torch.ops._C, "hadacore_transform"):

    @register_fake("_C::hadacore_transform")
    def _hadacore_transform_fake(x: torch.Tensor,
                                 inplace: bool) -> torch.Tensor:
        """Fake implementation registered for `_C::hadacore_transform`.

        Returns `x` itself when `inplace` is set, otherwise a new
        uninitialized tensor with the same properties as `x`.
        """
        if inplace:
            return x
        return torch.empty_like(x)
......@@ -129,7 +129,7 @@ class CompressedTensorsConfig(QuantizationConfig):
# choose transform method
if any((input_tfms, output_tfms)):
return CompressedTensorsLinearTransformMethod.from_schemes(
quant_method, input_tfms, output_tfms)
quant_method, quant_scheme, input_tfms, output_tfms)
else:
return quant_method
......
......@@ -12,6 +12,8 @@ from compressed_tensors.utils import is_match
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
LinearMethodBase,
QKVCrossParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501
HadamardTransform)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
......@@ -26,14 +28,22 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
@classmethod
def from_schemes(
cls, quant_method: LinearMethodBase, input_tfms: dict[int,
TransformTuple],
output_tfms: dict[int, TransformTuple]
cls,
quant_method: LinearMethodBase,
quant_scheme: Optional[CompressedTensorsScheme],
input_tfms: dict[int, TransformTuple],
output_tfms: dict[int, TransformTuple],
) -> "CompressedTensorsLinearTransformMethod":
from vllm.model_executor.layers.quantization.compressed_tensors.transform.schemes.linear_qutlass_nvfp4 import ( # noqa: E501
QutlassNvFP4LinearMethod, is_qutlass_fp4_scheme)
assert input_tfms or output_tfms
# TODO (@ksayers): implement QutlassLinearMethodNvFP4
# hadacore and fwht can be selected by Transform module
if is_qutlass_fp4_scheme(quant_scheme, input_tfms):
return QutlassNvFP4LinearMethod(quant_method, input_tfms,
output_tfms)
# hadacore or dense gemm is selected by Transform module
return cls(quant_method, input_tfms, output_tfms)
......@@ -129,11 +139,12 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
assert bias is None
x = self.quant_method.apply(layer, x, bias)
# TODO (@ksayers): Write a triton kernel to do this in parallel
# In most cases, input transforms are preferred over output transforms
# (@ksayers): confirm that this is done concurrently
if self.output_transform is not None:
for part_id, (start, length) in enumerate(self.partition_ranges):
x[:, start:start + length] = self.output_transform(
x[:, start:start + length], part_id=part_id)
x[:, start:start + length].contiguous(), part_id=part_id)
return x
......
......@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Hashable
from typing import Callable, Optional
from typing import Callable
import torch
from compressed_tensors.transform import TransformLocation, TransformScheme
from compressed_tensors.transform import (TransformArgs, TransformLocation,
TransformScheme)
from torch import Tensor
import vllm._custom_ops as ops
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import LinearBase
......@@ -28,16 +30,12 @@ class HadamardTransform(torch.nn.Module):
transforms: dict[int, TransformTuple] # info parsed from transforms config
weight: SharedWeightParameter # container for shared tensors
kernel: Callable # function used during application
scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0))
def __init__(self,
transforms: dict[int, TransformTuple],
layer: torch.nn.Module,
weight_loader: Callable,
def __init__(self, transforms: dict[int, TransformTuple],
layer: torch.nn.Module, weight_loader: Callable,
input_size_per_partition: int,
output_partition_sizes: list[int],
kernel: Optional[Callable] = None):
output_partition_sizes: list[int]):
super().__init__()
self.transforms = transforms
self.scales = {}
......@@ -55,7 +53,7 @@ class HadamardTransform(torch.nn.Module):
for part_index, (_scheme_name, scheme,
args) in self.transforms.items():
output_size = output_partition_sizes[part_index]
weight_size = self._get_weight_size(layer, args.location,
weight_size = self._get_weight_size(layer, scheme, args,
input_size, output_size)
data_key = self._get_data_key(scheme, weight_size)
......@@ -69,9 +67,6 @@ class HadamardTransform(torch.nn.Module):
# validate that shared tensors and schemes are correct
self._validate_input_transforms()
# select kernel based on transform schemes
self.kernel = self._infer_kernel() if kernel is None else kernel
def process_weights_after_loading(self):
for part_id in self.weight.partitions:
data = self.weight.partitions[part_id].data
......@@ -90,32 +85,59 @@ class HadamardTransform(torch.nn.Module):
if part_id not in self.weight.partitions:
return value
# use hadacore if possible
if self.transforms[part_id].scheme.type == "hadamard":
if self.transforms[part_id].scheme.head_dim is not None:
weight_size = self.transforms[part_id].scheme.head_dim
value = value.unflatten(-1, (-1, weight_size))
value = ops.hadacore_transform(value)
value = value.flatten(-2, -1)
return value
# sylvester transforms are symmetric, inv => transpose => original
return ops.hadacore_transform(value)
# fall back to dense
else:
weight = self.weight.partitions[part_id]
weight = weight if self.transforms[
part_id].args.inverse else weight.T # linear := x(W.T)
scale = self.scales[part_id]
return self.kernel(self, value.to(weight.dtype), weight, None).to(
value.dtype) * scale
if self.transforms[part_id].scheme.head_dim is not None:
value = value.unflatten(-1, (-1, weight.size(0)))
value = dispatch_unquantized_gemm()(self, value.to(
weight.dtype), weight, None).to(value.dtype) * scale
value = value.flatten(-2, -1)
return value
return dispatch_unquantized_gemm()(self, value.to(
weight.dtype), weight, None).to(value.dtype) * scale
def _get_data_key(self, scheme: TransformScheme,
                  weight_size: int) -> Hashable:
    """Return the hashable key `(id(scheme), weight_size)` for a transform.

    The key combines the identity of the scheme object with the weight size.
    NOTE(review): presumably this lets partitions that use the same scheme
    object and the same size refer to the same weight data -- confirm
    against the caller.

    :param scheme: transform scheme; only its object identity is used
    :param weight_size: size of the transform weight
    :return: tuple usable as a dictionary key
    """
    return (id(scheme), weight_size)
def _get_weight_size(self, layer: torch.nn.Module,
location: TransformLocation, input_size: int,
def _get_weight_size(self, layer: torch.nn.Module, scheme: TransformScheme,
args: TransformArgs, input_size: int,
output_size: int) -> int:
if scheme.head_dim is not None:
return scheme.head_dim
if isinstance(layer, LinearBase):
if location == TransformLocation.INPUT:
if args.location == TransformLocation.INPUT:
return input_size
elif location == TransformLocation.OUTPUT:
elif args.location == TransformLocation.OUTPUT:
return output_size
elif isinstance(layer, VocabParallelEmbedding):
if location == TransformLocation.INPUT:
if args.location == TransformLocation.INPUT:
return output_size
elif location == TransformLocation.OUTPUT:
elif args.location == TransformLocation.OUTPUT:
return input_size
raise ValueError()
......@@ -129,7 +151,3 @@ class HadamardTransform(torch.nn.Module):
for partition in self.weight.partitions.values():
if partition.data.data_ptr() != first_data.data_ptr():
raise ValueError("")
def _infer_kernel(self) -> Callable:
# TODO (@ksayers): use fwht, hadacore
return dispatch_unquantized_gemm()
......@@ -4,18 +4,43 @@ from typing import Optional
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsScheme, CompressedTensorsW4A4Fp4)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod)
CompressedTensorsLinearTransformMethod, TransformTuple)
__all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"]
# Because qutlass fuses hadamard with quantization, it cannot automatically be
# composed with kernels in the way CompressedTensorsLinearTransformMethod does.
# Therefore, a separate scheme must be created for each quantized dtype
class QutlassLinearMethodNvFP4(CompressedTensorsLinearTransformMethod):
def is_qutlass_fp4_scheme(quant_scheme: Optional[CompressedTensorsScheme],
                          input_tfms: dict[int, TransformTuple]) -> bool:
    """Tell whether a scheme / input-transform pair matches the qutlass path.

    All of the following must hold, checked in this order:

    * `quant_scheme` is a `CompressedTensorsW4A4Fp4` instance,
    * there is exactly one input transform,
    * the `head_dim` of the transform stored under key `0` equals the
      `group_size` of the quantization scheme.

    :param quant_scheme: quantization scheme of the layer, may be `None`
    :param input_tfms: input transforms keyed by integer index
    :return: whether the fused qutlass nvfp4 method applies
    """
    if not isinstance(quant_scheme, CompressedTensorsW4A4Fp4):
        return False

    if len(input_tfms) != 1:
        return False

    return input_tfms[0].scheme.head_dim == quant_scheme.group_size
class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod):
    """Linear transform method for the fused qutlass hadamard + nvfp4 path.

    Weight creation is delegated to the parent class and then validated;
    the forward pass (`apply`) is not implemented yet.
    """

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        """Create weights through the parent class and validate the result.

        All arguments are forwarded unchanged to the parent `create_weights`.
        The layer scheme must be `CompressedTensorsW4A4Fp4`, and afterwards
        there must be a single input transform weight whose first dimension
        equals the scheme's `group_size`.

        :return: whatever the parent `create_weights` returns
        """
        # initializes fp4 qparams
        assert isinstance(layer.scheme, CompressedTensorsW4A4Fp4)

        created = super().create_weights(
            layer,
            input_size_per_partition,
            output_partition_sizes,
            input_size,
            output_size,
            params_dtype,
            **extra_weight_attrs,
        )

        # exactly one input transform, sized like the quantization group
        assert self.input_transform is not None
        assert len(self.input_transform.weight) == 1
        assert (self.input_transform.weight[0].size(0) ==
                layer.scheme.group_size)

        return created

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Fused hadamard + quantized linear forward pass (not implemented).

        :raises NotImplementedError: always
        """
        # fused hadamard quant linear method
        raise NotImplementedError()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment