Unverified Commit 60508ffd authored by Dipika Sikka's avatar Dipika Sikka Committed by GitHub
Browse files

[Kernel]: Cutlass 2:4 Sparsity + FP8/Int8 Quant Support (#10995)


Co-authored-by: default avatarFaraz Shahsavan <faraz.shahsavan@gmail.com>
Co-authored-by: default avatarilmarkov <markovilya197@gmail.com>
Co-authored-by: default avatarRahul Tuli <rahul@neuralmagic.com>
Co-authored-by: default avatarrshaw@neuralmagic.com <rshaw@neuralmagic.com>
parent f04e407e
"""Tests for sparse cutlass kernels
Run `pytest tests/kernels/test_semi_structured.py`.
"""
from typing import Optional, Tuple, Type
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def rand_int8(shape: tuple, device: str = "cuda"):
return to_int8(torch.rand(shape, device=device) * 255 - 128)
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.bfloat16)
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.float16)
def prune_to_2_4(tensor):
# Reshape tensor to [N, 4] where N is number of groups of 4
original_shape = tensor.shape
reshaped = tensor.reshape(-1, 4)
# Get indices of top 2 absolute values in each group of 4
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
# Create binary mask
mask = torch.zeros_like(reshaped)
mask.scatter_(dim=1,
index=indices,
src=torch.ones_like(indices, dtype=mask.dtype))
# Apply mask and reshape back
pruned = reshaped * mask
# Turn all -0.0 to 0.0
pruned[pruned == -0.0] = 0.0
return pruned.reshape(original_shape)
def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5
b = prune_to_2_4(b.t()).t()
if dtype == torch.int8:
a, b = to_int8(a), to_int8(b)
elif dtype == torch.float8_e4m3fn:
a, b = to_fp8(a), to_fp8(b)
elif dtype == torch.float16:
a, b = to_fp16(a), to_fp16(b)
elif dtype == torch.bfloat16:
a, b = to_bf16(a), to_bf16(b)
else:
raise ValueError("unsupported dtype")
b_compressed, e = ops.cutlass_sparse_compress(b.t())
# Compressed B, Metadata, Original A, B
return b_compressed, e, a, b
def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
output = (scale_a * (scale_b * (torch.mm(
a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
if bias is not None:
output = output + bias
return output
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.")
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
big_m = 1024
m, n, k = 512, 512, 512
# Create tensors
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
big_m, n, k)
a = whole_a[0:m, 0:k]
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
......@@ -10,9 +10,11 @@ from compressed_tensors.quantization import QuantizationType
from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
CompressedTensors24, CompressedTensorsLinearMethod,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
from vllm.platforms import current_platform
@pytest.mark.parametrize(
......@@ -208,3 +210,98 @@ def test_compressed_tensors_kv_cache(vllm_runner):
with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
output = llm.generate_greedy("Hello world!", max_tokens=20)
assert output
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.")
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)
assert qkv_proj.scheme.weight_quant.strategy == weight_strategy
assert qkv_proj.scheme.input_quant.strategy == input_strategy
assert qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
assert sparsity_map.get("Linear").format == "dense"
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel",
"token"),
("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
"channel", "tensor"),
("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", "tensor",
"tensor"),
("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
"tensor", "token"),
])
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
"channel", "token"),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", "tensor",
"tensor"),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
"tensor", "token"),
])
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.int8
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize(
"args_2of4",
[("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")])
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
model = args_2of4
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)
assert qkv_proj.scheme.weight_quant is None
assert qkv_proj.scheme.input_quant is None
assert not qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
assert sparsity_map.get("Linear").format == "dense"
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
......@@ -21,6 +21,8 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
......
......@@ -26,6 +26,10 @@ do
export QUANTIZATION=${array[0]}
export MODEL_NAME=${array[1]}
export REVISION=${array[2]}
# If array length is larger than 3, then MIN_CAPABILITY is provided
if [ ${#array[@]} -gt 3 ]; then
export MIN_CAPABILITY=${array[3]}
fi
pytest -s weight_loading/test_weight_loading.py || LOCAL_SUCCESS=$?
if [[ $LOCAL_SUCCESS == 0 ]]; then
......
import os
import pytest
import torch
from vllm.platforms import current_platform
MAX_MODEL_LEN = 1024
MODEL_NAME = os.environ.get("MODEL_NAME",
"robertgshaw2/zephyr-7b-beta-channelwise-gptq")
REVISION = os.environ.get("REVISION", "main")
QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")
MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "89")
@pytest.mark.skipif(
not current_platform.has_device_capability(int(MIN_CAPABILITY)),
reason="Current system does not have minimum capability.")
def test_weight_loading(vllm_runner):
"""
Test parameter weight loading with tp>1.
......
......@@ -552,6 +552,109 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
return out
def cutlass_sparse_compress(a: torch.Tensor) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""
Compresses a sparse matrix for use with Cutlass sparse operations.
This function takes a dense tensor and compresses it into two components:
non-zero elements and metadata. The compressed representation is compatible
with Cutlass sparse kernels.
Args:
a (torch.Tensor):
The input tensor to be compressed. Must have one of the following data types:
- `torch.int8`
- `torch.float8_e4m3fn`
- `torch.bfloat16`
- `torch.float16`
Returns:
Tuple[torch.Tensor, torch.Tensor]:
A tuple containing:
- `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
- `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.
Raises:
ValueError: If the compression operation fails.
Notes:
- The `a_meta` tensor has a data type of `torch.uint8`.
- Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`).
- The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor.
- The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`.
"""
assert (a.dtype in [
torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
])
assert (a.is_contiguous())
# a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
elemsPerMetaElem = 4
m = a.shape[0]
k = a.shape[1]
assert (k % 2 == 0)
a_nzs = torch.empty((m, k // 2), dtype=a.dtype, device=a.device)
a_meta = torch.empty((m, k // 2 // elemsPerMetaElem),
dtype=torch.uint8,
device=a.device)
if not (torch.ops._C.cutlass_sparse_compress_entry(a_nzs, a_meta, a)):
raise ValueError
assert (a_nzs.is_contiguous())
assert (a_meta.is_contiguous())
return a_nzs, a_meta
def cutlass_scaled_sparse_mm(
a: torch.Tensor,
bt_nzs: torch.Tensor,
bt_meta: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Performs a scaled sparse matrix multiplication using Cutlass.
Steps:
1. Create a dense matrix `a` of shape (m, k) on the CUDA device:
`a = torch.randn((m, k), device='cuda')`.
2. Create a dense matrix `b` of shape (k, n) on the CUDA device:
`b = torch.randn((k, n), device='cuda')`.
3. Prune matrix `b` to 2:4 sparsity along the specified dimension:
`b = prune_to_2_4(b, dim=0)`.
4. Compress the transposed sparse matrix `b.t()`:
`bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`.
5. Perform sparse matrix multiplication using the compressed matrix,
applying scaling factors for `a` and `b`, and the output data type:
`out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`.
Returns:
- The result of the scaled sparse matrix multiplication.
"""
assert (bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.shape[0] == bt_nzs.shape[0] \
and bias.dtype == out_dtype
m = a.shape[0]
n = bt_nzs.shape[0]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a,
scale_b, bias)
return out
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
......
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List, Literal, Optional, cast
import torch
from compressed_tensors.config import CompressionFormat
from compressed_tensors.config import (CompressionFormat,
SparsityCompressionConfig,
SparsityStructure)
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
......@@ -15,7 +17,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
CompressedTensorsMoEMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
......@@ -27,20 +29,29 @@ from vllm.platforms import current_platform
__all__ = ["CompressedTensorsLinearMethod"]
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]
class CompressedTensorsConfig(QuantizationConfig):
def __init__(self,
target_scheme_map: Dict[str, Any],
ignore: List[str],
quant_format: str,
kv_cache_scheme: Optional[Dict[str, Any]] = None):
def __init__(
self,
target_scheme_map: Dict[str, Any],
ignore: List[str],
quant_format: str,
sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
kv_cache_scheme: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
):
self.ignore = ignore
self.quant_format = quant_format
# Map from [target -> scheme]
self.target_scheme_map = target_scheme_map
self.kv_cache_scheme = kv_cache_scheme
self.sparsity_scheme_map = sparsity_scheme_map
self.config = config
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
......@@ -78,8 +89,50 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
ignore: List[str] = cast(List[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(
config=config)
sparsity_scheme_map = cls._sparsity_scheme_map_from_config(
config=config)
return cls(
target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format,
sparsity_scheme_map=sparsity_scheme_map,
config=config,
)
@classmethod
def _sparsity_scheme_map_from_config(
cls, config: Dict[str,
Any]) -> Dict[str, SparsityCompressionConfig]:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
sparsity compression configurations
"""
if (sparsity_config := config.get(SPARSITY_CONFIG_NAME)) is None:
return dict()
sparsity_config = SparsityCompressionConfig.model_validate(
sparsity_config)
sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
target: sparsity_config
for target in sparsity_config.targets or list()
}
return sparse_scheme_map
@classmethod
def _quantization_scheme_map_from_config(
cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations
"""
target_scheme_map: Dict[str, Any] = dict()
ignore = cast(List[str], config.get("ignore"))
quant_format = cast(str, config.get("format"))
# The quant_config has multiple config_groups, each containing
......@@ -90,12 +143,14 @@ class CompressedTensorsConfig(QuantizationConfig):
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
for _, quant_config in config["config_groups"].items():
config_groups = config.get("config_groups", dict())
for _, quant_config in config_groups.items():
targets = quant_config.get("targets")
for target in targets:
target_scheme_map[target] = {}
target_scheme_map[target][
"weights"] = QuantizationArgs.parse_obj(
"weights"] = QuantizationArgs.model_validate(
quant_config.get("weights"))
target_scheme_map[target]["input_activations"] = None
......@@ -110,13 +165,9 @@ class CompressedTensorsConfig(QuantizationConfig):
"weights"].type == QuantizationType.FLOAT
else:
target_scheme_map[target][
"input_activations"] = QuantizationArgs.parse_obj(
"input_activations"] = QuantizationArgs.model_validate( # noqa: E501
quant_config.get("input_activations"))
return cls(target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format,
kv_cache_scheme=config.get("kv_cache_scheme"))
return target_scheme_map
@classmethod
def get_config_filenames(cls) -> List[str]:
......@@ -315,23 +366,105 @@ class CompressedTensorsConfig(QuantizationConfig):
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys())
# Find the quant_scheme
scheme_dict = self.target_scheme_map[matched_target]
scheme = self._get_scheme_from_parts(
weight_quant=scheme_dict["weights"],
input_quant=scheme_dict["input_activations"])
# Will be empty for models with only sparsity
if self.target_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys())
scheme_dict = self.target_scheme_map[matched_target]
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
elif self.sparsity_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_scheme_map.keys())
weight_quant = None
input_quant = None
# For models with sparsity, assumes that the sparse layers are also
# quantized for cutlass 2:4 support
sparsity_scheme: Optional[
SparsityCompressionConfig] = self.sparsity_scheme_map.get(
matched_target)
if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant,
sparsity_scheme=sparsity_scheme):
# Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel
scheme = CompressedTensors24(quantized=weight_quant is not None
or input_quant is not None,
weight_quant=weight_quant,
input_quant=input_quant)
else:
# Find the quant_scheme
scheme = self._get_scheme_from_parts( # type: ignore
weight_quant=weight_quant,
input_quant=input_quant,
)
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
return scheme
@staticmethod
def supports_cutlass_24(
weight_quant: Optional[QuantizationArgs],
input_quant: Optional[QuantizationArgs],
sparsity_scheme: Optional[SparsityCompressionConfig] = None
) -> bool:
"""
Check if the layer is supported by the Cutlass 2:4 Kernel
Conditions:
- Overarching condition: Sparsity Structure is 2:4
- Unquantized cases are supported
- Weight only quantization is not-supported
- Supported weight quantization strategies are TENSOR and CHANNEL
- Supported input quantization strategies are TENSOR and TOKEN
- Only 8 bit quantization is supported
:return: True if the layer is supported by the Cutlass 2:4 Kernel
False otherwise
"""
is_valid_sparsity = (sparsity_scheme is not None
and sparsity_scheme.sparsity_structure
== SparsityStructure.TWO_FOUR.value
and sparsity_scheme.format == "dense")
if not is_valid_sparsity:
return False
# Unquantized cases are supported
if weight_quant is None and input_quant is None:
return True
# Weight only quantization is not-supported
if weight_quant is not None and input_quant is None:
return False
supported_weight_quant_strategies = [
QuantizationStrategy.TENSOR.value,
QuantizationStrategy.CHANNEL.value
]
assert weight_quant is not None
assert input_quant is not None
if weight_quant.strategy not in supported_weight_quant_strategies:
return False
supported_input_quant_strategies = [
QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value
]
if input_quant.strategy not in supported_input_quant_strategies:
return False
return weight_quant.num_bits == input_quant.num_bits == 8
class CompressedTensorsLinearMethod(LinearMethodBase):
......
......@@ -7,13 +7,12 @@ from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
CompressedTensorsWNA16)
from .compressed_tensors_24 import CompressedTensors24 # isort: skip
__all__ = [
"CompressedTensorsScheme",
"CompressedTensorsWNA16",
"CompressedTensorsW8A16Fp8",
"CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8",
"CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS",
"W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensorsScheme", "CompressedTensorsWNA16",
"CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensors24"
]
from typing import Callable, List, Optional
import torch
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
__all__ = ["CompressedTensors24"]
class CompressedTensors24(CompressedTensorsScheme):
def __init__(self,
quantized: bool = False,
weight_quant: Optional[QuantizationArgs] = None,
input_quant: Optional[QuantizationArgs] = None):
self.quantized = quantized
self.weight_quant = weight_quant
self.input_quant = input_quant
@classmethod
def get_min_capability(cls) -> int:
# Only cutlass 3.x kernels are implemented so far
return 90
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.output_dtype = params_dtype
layer.logical_widths = output_partition_sizes
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
# parameter to store uncompressed weight
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=self.weights_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
# Check if quantized, not just 2:4 Sparse
if self.quantized:
if (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.CHANNEL.value):
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.TENSOR.value)
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
# input quant will be non-none
if self.input_quant and not self.input_quant.dynamic:
# register input quant scale
assert (self.input_quant.strategy ==
QuantizationStrategy.TENSOR.value)
input_scale = BasevLLMParameter(data=torch.empty(
1, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
else:
# for sparse-only, pass in 1 for weight/input scales
weight_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
input_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("input_scale", input_scale)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight", weight)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""
Compress weights after loading. Store compressed weight and meta
tensor
:post-condition: layer.w_compressed and layer.meta are
set to the compressed weight and meta tensor in the
format expected by the Cutlass kernels
:param layer: The layer with the weights to be processed
"""
# torch.compile workaround
if hasattr(layer, "input_scale"):
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
if self.weight_quant:
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths),
requires_grad=False)
else:
# torch.compile workaround
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False)
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Returns the output tensor for the layer with 2:4
sparse compressed weights, given the input tensor
and bias
:param layer: The layer with 2:4 sparse compressed
weights to be used for the computation
:param x: The input tensor to the layer
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
"""
if self.quantized:
scale = None
if hasattr(layer, "input_scale"):
scale = layer.input_scale
if self.weights_dtype == torch.int8:
ops_output = ops.scaled_int8_quant(x, scale=scale)
q_input = ops_output[0]
input_scale = ops_output[1]
else:
assert self.weights_dtype == torch.float8_e4m3fn
if scale is not None:
q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
else:
q_input, input_scale = ops.scaled_fp8_quant(
x, use_per_token_if_dynamic=True)
else:
# Not quantized, nothing to do with the input_scales, use as is
input_scale = layer.input_scale
q_input = x
out = ops.cutlass_scaled_sparse_mm(a=q_input,
bt_nzs=layer.weight,
bt_meta=layer.meta,
scale_a=input_scale,
scale_b=layer.weight_scale,
out_dtype=self.output_dtype,
bias=bias)
assert out.is_contiguous()
return out
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
if not self.quantized:
return params_dtype
assert self.weight_quant is not None
assert self.input_quant is not None
is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8
if not is_8_bits:
raise ValueError("Cutlass only supports 8-bit quantization")
if (self.weight_quant.type == QuantizationType.FLOAT
and self.input_quant.type == QuantizationType.FLOAT):
return torch.float8_e4m3fn
if (self.weight_quant.type == QuantizationType.INT
and self.input_quant.type == QuantizationType.INT):
return torch.int8
raise ValueError("Quantization type not supported by Cutlass")
def check_24(tensor):
new_tensor = tensor.view(-1, 4)
zero_counts = (new_tensor == 0).sum(dim=1)
return (zero_counts >= 2).all().item()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment