Unverified Commit a0b26701 authored by Kyle Sayers's avatar Kyle Sayers Committed by GitHub
Browse files

[Transform] Deterministic Hadacore Transforms (#24106)


Signed-off-by: default avatarKyle Sayers <kylesayrs@gmail.com>
parent c4afdb69
...@@ -783,6 +783,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -783,6 +783,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
endif() endif()
# Hadacore kernels
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0;8.9;9.0" "${CUDA_ARCHS}")
if(HADACORE_ARCHS)
set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${HADACORE_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building hadacore")
endif()
# if CUDA endif # if CUDA endif
endif() endif()
......
...@@ -347,6 +347,8 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle( ...@@ -347,6 +347,8 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t open_mem_handle(torch::Tensor& mem_handle); int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer); void free_shared_buffer(int64_t buffer);
torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace);
#ifdef USE_ROCM #ifdef USE_ROCM
fptr_t init_custom_qr(int64_t rank, int64_t world_size, fptr_t init_custom_qr(int64_t rank, int64_t world_size,
std::optional<int64_t> qr_max_size = std::nullopt); std::optional<int64_t> qr_max_size = std::nullopt);
......
This diff is collapsed.
...@@ -613,6 +613,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -613,6 +613,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int pad_slot_id) -> ()"); "int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
// Hadamard transforms
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
#ifndef USE_ROCM #ifndef USE_ROCM
// Compute per-token-group FP8 quantized tensor and scaling factor. // Compute per-token-group FP8 quantized tensor and scaling factor.
ops.def( ops.def(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import pytest
import torch
from compressed_tensors.transform import deterministic_hadamard_matrix
from vllm import _custom_ops as ops
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("hidden_dim", [2**n for n in range(10)])
def test_hadacore(batch_size, hidden_dim, dtype=torch.bfloat16, device="cuda"):
    """
    Check the hadacore kernel against a dense Hadamard matmul and verify
    that applying the transform twice recovers the original input.

    :param batch_size: parametrized value; NOTE(review): currently unused,
        the input is always a (hidden_dim, hidden_dim) identity matrix
    :param hidden_dim: size of the transformed dimension (powers of two)
    :param dtype: dtype of the tensor handed to the kernel
    :param device: device on which both the input and the reference are built
    """
    x = torch.eye(hidden_dim, dtype=dtype, device=device)

    # Build the normalized reference on the same device as the input rather
    # than a hard-coded "cuda", so that overriding `device` stays consistent.
    hadamard = deterministic_hadamard_matrix(
        hidden_dim, dtype=torch.float64, device=device) / math.sqrt(hidden_dim)

    # Pass a clone so that `x` itself is kept intact for the round-trip
    # comparison below.
    y = ops.hadacore_transform(x.clone())
    y_true = (x.to(hadamard.dtype) @ hadamard.T).to(y.dtype)
    assert torch.allclose(y, y_true)

    # Applying the transform a second time should give back the input.
    y = ops.hadacore_transform(y)
    assert torch.allclose(y, x)
...@@ -2011,3 +2011,27 @@ def onednn_scaled_mm( ...@@ -2011,3 +2011,27 @@ def onednn_scaled_mm(
input_zp_adj, bias, dnnl_handler.handler) input_zp_adj, bias, dnnl_handler.handler)
return output return output
def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor:
    """
    Apply a Hadamard transform to ``x`` using the
    [Hadacore](https://arxiv.org/abs/2412.08832) kernels.

    These kernels exploit the recursive structure of Sylvester Hadamard
    matrices, and therefore do not require any transform weight data.

    Sylvester Hadamard transforms are also symmetric, which means that this
    function also applies the (transpose <=> inverse) transform.

    :param x: value to be transformed
    :param inplace: modify value in place
    :return: value after transformation
    """
    kernel = torch.ops._C.hadacore_transform
    return kernel(x, inplace)
# Only register the fake implementation when the op is present on
# torch.ops._C.
if hasattr(torch.ops._C, "hadacore_transform"):

    @register_fake("_C::hadacore_transform")
    def _hadacore_transform_fake(x: torch.Tensor,
                                 inplace: bool) -> torch.Tensor:
        # In-place calls hand back the input itself; otherwise return a new
        # uninitialized tensor created with torch.empty_like(x).
        if inplace:
            return x
        return torch.empty_like(x)
...@@ -129,7 +129,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -129,7 +129,7 @@ class CompressedTensorsConfig(QuantizationConfig):
# choose transform method # choose transform method
if any((input_tfms, output_tfms)): if any((input_tfms, output_tfms)):
return CompressedTensorsLinearTransformMethod.from_schemes( return CompressedTensorsLinearTransformMethod.from_schemes(
quant_method, input_tfms, output_tfms) quant_method, quant_scheme, input_tfms, output_tfms)
else: else:
return quant_method return quant_method
......
...@@ -12,6 +12,8 @@ from compressed_tensors.utils import is_match ...@@ -12,6 +12,8 @@ from compressed_tensors.utils import is_match
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
LinearMethodBase, LinearMethodBase,
QKVCrossParallelLinear) QKVCrossParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501 from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501
HadamardTransform) HadamardTransform)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501 from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
...@@ -26,14 +28,22 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase): ...@@ -26,14 +28,22 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
@classmethod @classmethod
def from_schemes( def from_schemes(
cls, quant_method: LinearMethodBase, input_tfms: dict[int, cls,
TransformTuple], quant_method: LinearMethodBase,
output_tfms: dict[int, TransformTuple] quant_scheme: Optional[CompressedTensorsScheme],
input_tfms: dict[int, TransformTuple],
output_tfms: dict[int, TransformTuple],
) -> "CompressedTensorsLinearTransformMethod": ) -> "CompressedTensorsLinearTransformMethod":
from vllm.model_executor.layers.quantization.compressed_tensors.transform.schemes.linear_qutlass_nvfp4 import ( # noqa: E501
QutlassNvFP4LinearMethod, is_qutlass_fp4_scheme)
assert input_tfms or output_tfms assert input_tfms or output_tfms
# TODO (@ksayers): implement QutlassLinearMethodNvFP4 if is_qutlass_fp4_scheme(quant_scheme, input_tfms):
# hadacore and fwht can be selected by Transform module return QutlassNvFP4LinearMethod(quant_method, input_tfms,
output_tfms)
# hadacore or dense gemm is selected by Transform module
return cls(quant_method, input_tfms, output_tfms) return cls(quant_method, input_tfms, output_tfms)
...@@ -129,11 +139,12 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase): ...@@ -129,11 +139,12 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
assert bias is None assert bias is None
x = self.quant_method.apply(layer, x, bias) x = self.quant_method.apply(layer, x, bias)
# TODO (@ksayers): Write a triton kernel to do this in parallel # In most cases, input transforms are preferred over output transforms
# (@ksayers): confirm that this is done concurrently
if self.output_transform is not None: if self.output_transform is not None:
for part_id, (start, length) in enumerate(self.partition_ranges): for part_id, (start, length) in enumerate(self.partition_ranges):
x[:, start:start + length] = self.output_transform( x[:, start:start + length] = self.output_transform(
x[:, start:start + length], part_id=part_id) x[:, start:start + length].contiguous(), part_id=part_id)
return x return x
......
...@@ -2,12 +2,14 @@ ...@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math import math
from collections.abc import Hashable from collections.abc import Hashable
from typing import Callable, Optional from typing import Callable
import torch import torch
from compressed_tensors.transform import TransformLocation, TransformScheme from compressed_tensors.transform import (TransformArgs, TransformLocation,
TransformScheme)
from torch import Tensor from torch import Tensor
import vllm._custom_ops as ops
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.linear import LinearBase
...@@ -28,16 +30,12 @@ class HadamardTransform(torch.nn.Module): ...@@ -28,16 +30,12 @@ class HadamardTransform(torch.nn.Module):
transforms: dict[int, TransformTuple] # info parsed from transforms config transforms: dict[int, TransformTuple] # info parsed from transforms config
weight: SharedWeightParameter # container for shared tensors weight: SharedWeightParameter # container for shared tensors
kernel: Callable # function used during application
scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0)) scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0))
def __init__(self, def __init__(self, transforms: dict[int, TransformTuple],
transforms: dict[int, TransformTuple], layer: torch.nn.Module, weight_loader: Callable,
layer: torch.nn.Module,
weight_loader: Callable,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: list[int], output_partition_sizes: list[int]):
kernel: Optional[Callable] = None):
super().__init__() super().__init__()
self.transforms = transforms self.transforms = transforms
self.scales = {} self.scales = {}
...@@ -55,7 +53,7 @@ class HadamardTransform(torch.nn.Module): ...@@ -55,7 +53,7 @@ class HadamardTransform(torch.nn.Module):
for part_index, (_scheme_name, scheme, for part_index, (_scheme_name, scheme,
args) in self.transforms.items(): args) in self.transforms.items():
output_size = output_partition_sizes[part_index] output_size = output_partition_sizes[part_index]
weight_size = self._get_weight_size(layer, args.location, weight_size = self._get_weight_size(layer, scheme, args,
input_size, output_size) input_size, output_size)
data_key = self._get_data_key(scheme, weight_size) data_key = self._get_data_key(scheme, weight_size)
...@@ -69,9 +67,6 @@ class HadamardTransform(torch.nn.Module): ...@@ -69,9 +67,6 @@ class HadamardTransform(torch.nn.Module):
# validate that shared tensors and schemes are correct # validate that shared tensors and schemes are correct
self._validate_input_transforms() self._validate_input_transforms()
# select kernel based on transform schemes
self.kernel = self._infer_kernel() if kernel is None else kernel
def process_weights_after_loading(self): def process_weights_after_loading(self):
for part_id in self.weight.partitions: for part_id in self.weight.partitions:
data = self.weight.partitions[part_id].data data = self.weight.partitions[part_id].data
...@@ -90,32 +85,59 @@ class HadamardTransform(torch.nn.Module): ...@@ -90,32 +85,59 @@ class HadamardTransform(torch.nn.Module):
if part_id not in self.weight.partitions: if part_id not in self.weight.partitions:
return value return value
# use hadacore if possible
if self.transforms[part_id].scheme.type == "hadamard":
if self.transforms[part_id].scheme.head_dim is not None:
weight_size = self.transforms[part_id].scheme.head_dim
value = value.unflatten(-1, (-1, weight_size))
value = ops.hadacore_transform(value)
value = value.flatten(-2, -1)
return value
# sylvester transforms are symmetric, inv => transpose => original
return ops.hadacore_transform(value)
# fall back to dense
else:
weight = self.weight.partitions[part_id] weight = self.weight.partitions[part_id]
weight = weight if self.transforms[ weight = weight if self.transforms[
part_id].args.inverse else weight.T # linear := x(W.T) part_id].args.inverse else weight.T # linear := x(W.T)
scale = self.scales[part_id] scale = self.scales[part_id]
return self.kernel(self, value.to(weight.dtype), weight, None).to(
value.dtype) * scale if self.transforms[part_id].scheme.head_dim is not None:
value = value.unflatten(-1, (-1, weight.size(0)))
value = dispatch_unquantized_gemm()(self, value.to(
weight.dtype), weight, None).to(value.dtype) * scale
value = value.flatten(-2, -1)
return value
return dispatch_unquantized_gemm()(self, value.to(
weight.dtype), weight, None).to(value.dtype) * scale
def _get_data_key(self, scheme: TransformScheme, def _get_data_key(self, scheme: TransformScheme,
weight_size: int) -> Hashable: weight_size: int) -> Hashable:
return (id(scheme), weight_size) return (id(scheme), weight_size)
def _get_weight_size(self, layer: torch.nn.Module, def _get_weight_size(self, layer: torch.nn.Module, scheme: TransformScheme,
location: TransformLocation, input_size: int, args: TransformArgs, input_size: int,
output_size: int) -> int: output_size: int) -> int:
if scheme.head_dim is not None:
return scheme.head_dim
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if location == TransformLocation.INPUT: if args.location == TransformLocation.INPUT:
return input_size return input_size
elif location == TransformLocation.OUTPUT: elif args.location == TransformLocation.OUTPUT:
return output_size return output_size
elif isinstance(layer, VocabParallelEmbedding): elif isinstance(layer, VocabParallelEmbedding):
if location == TransformLocation.INPUT: if args.location == TransformLocation.INPUT:
return output_size return output_size
elif location == TransformLocation.OUTPUT: elif args.location == TransformLocation.OUTPUT:
return input_size return input_size
raise ValueError() raise ValueError()
...@@ -129,7 +151,3 @@ class HadamardTransform(torch.nn.Module): ...@@ -129,7 +151,3 @@ class HadamardTransform(torch.nn.Module):
for partition in self.weight.partitions.values(): for partition in self.weight.partitions.values():
if partition.data.data_ptr() != first_data.data_ptr(): if partition.data.data_ptr() != first_data.data_ptr():
raise ValueError("") raise ValueError("")
def _infer_kernel(self) -> Callable:
# TODO (@ksayers): use fwht, hadacore
return dispatch_unquantized_gemm()
...@@ -4,18 +4,43 @@ from typing import Optional ...@@ -4,18 +4,43 @@ from typing import Optional
import torch import torch
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsScheme, CompressedTensorsW4A4Fp4)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod) CompressedTensorsLinearTransformMethod, TransformTuple)
__all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"]
# Because qutlass fuses hadamard with quantization, it cannot automatically be
# composed with kernels in the way CompressedTensorsLinearTransformMethod does. def is_qutlass_fp4_scheme(quant_scheme: Optional[CompressedTensorsScheme],
# Therefore, a separate scheme must be created for each quantized dtype input_tfms: dict[int, TransformTuple]) -> bool:
class QutlassLinearMethodNvFP4(CompressedTensorsLinearTransformMethod): return isinstance(
quant_scheme,
(CompressedTensorsW4A4Fp4, )) and len(input_tfms) == 1 and input_tfms[
0].scheme.head_dim == quant_scheme.group_size
class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod):
def create_weights(self, layer, input_size_per_partition,
output_partition_sizes, input_size, output_size,
params_dtype, **extra_weight_attrs):
# initializes fp4 qparams
assert isinstance(layer.scheme, (CompressedTensorsW4A4Fp4, ))
ret = super().create_weights(layer, input_size_per_partition,
output_partition_sizes, input_size,
output_size, params_dtype,
**extra_weight_attrs)
assert self.input_transform is not None
assert len(self.input_transform.weight) == 1
assert self.input_transform.weight[0].size(
0) == layer.scheme.group_size
return ret
def apply(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# fused hadamard quant linear method
raise NotImplementedError() raise NotImplementedError()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment