Unverified commit e9cd6911, authored by Kyle Sayers and committed by GitHub
Browse files

[Bugfix] Fix Sparse24 Compressed Tensors models (#33446)


Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
parent 80f2ba6e
......@@ -6,11 +6,11 @@
#include "cutlass_extensions/common.hpp"
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
// sparse CUTLASS kernels need at least
// sparse CUTLASS kernels need exactly hopper and are not forward compatible
// CUDA 12.2 and SM90 (Hopper)
#if defined CUDA_VERSION
return CUDA_VERSION >= 12020 && cuda_device_capability >= 90;
return CUDA_VERSION >= 12020 && cuda_device_capability == 90;
#endif
return false;
......@@ -98,7 +98,7 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a) {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_sparse_compress for a compute capability less than "
"No compiled cutlass_sparse_compress for a compute capability equal to "
"CUDA device capability: ",
version_num);
}
......@@ -207,13 +207,14 @@ class CompressedTensorsConfig(QuantizationConfig):
# because Attention quantization on its own is not supported by vLLM.
# It is coupled with KV-cache quantization, and if scales are present in the
# checkpoint, they will be used properly.
if "config_groups" in config:
grps_without_attn_quant = {}
for k, v in config["config_groups"].items():
# e.g. LlamaAttention, Qwen3Attention, etc.
if len(v["targets"]) == 1 and v["targets"][0].endswith("Attention"):
logger.warning(
"Skipping CompressedTensors config group for %s. Attention quant "
"is coupled with KV-cache quantization in vLLM.",
"Skipping CompressedTensors config group for %s. Attention "
"quant is coupled with KV-cache quantization in vLLM.",
v["targets"][0],
)
continue
......
......@@ -261,6 +261,7 @@ def get_quant_config(
if (
hf_quant_config is not None
and hf_quant_config.get("quant_method") == "compressed-tensors"
and "config_groups" in hf_quant_config
):
if hf_text_config is not None:
n_heads = getattr(hf_text_config, "num_attention_heads", None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment