Commit 11c7ace3 (unverified), authored by Eldar Kurtić, committed by GitHub
Browse files

[Bugfix] Enable attn quantization of Llama-4 by correctly permuting scales for...


[Bugfix] Enable attn quantization of Llama-4 by correctly permuting scales for rope (int8, fp8) (#34243)
Signed-off-by: Your Name <you@example.com>
Co-authored-by: Your Name <you@example.com>
parent be7f3d5d
...@@ -44,6 +44,9 @@ from vllm.model_executor.layers.linear import ( ...@@ -44,6 +44,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors import (
compressed_tensors as ct,
)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
...@@ -829,11 +832,20 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts): ...@@ -829,11 +832,20 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
) -> tuple[str, torch.Tensor]: ) -> tuple[str, torch.Tensor]:
# Helper function to permute the weight's channels # Helper function to permute the weight's channels
def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool): def permute(
w: torch.Tensor,
n_heads: int,
is_nvfp4_weight_scale: bool,
is_ct_int8_or_fp8_weight_scale: bool,
):
# Calculate the expected shape of the weight. # Calculate the expected shape of the weight.
# Do not rely on w's shape, as it may be in another layout. # Do not rely on w's shape, as it may be in another layout.
attn_in = self.config.head_dim * n_heads attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size attn_out = (
self.config.hidden_size
if not is_ct_int8_or_fp8_weight_scale
else w.shape[-1]
)
# If the weight is FP4 packed as uint8, we need to divide attn_out # If the weight is FP4 packed as uint8, we need to divide attn_out
# by 2. # by 2.
...@@ -844,7 +856,7 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts): ...@@ -844,7 +856,7 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
# block size, which is currently 16. # block size, which is currently 16.
elif ( elif (
w.dtype == torch.float8_e4m3fn w.dtype == torch.float8_e4m3fn
and is_weight_scale and is_nvfp4_weight_scale
and w.shape[1] * 16 == attn_out and w.shape[1] * 16 == attn_out
): ):
attn_out = attn_out // 16 attn_out = attn_out // 16
...@@ -862,19 +874,31 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts): ...@@ -862,19 +874,31 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
is_nvfp4_weight_scale = ( is_nvfp4_weight_scale = (
modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
) )
is_ct_int8_or_fp8_weight_scale = False
if is_weight or is_nvfp4_weight_scale: if modules[-1] == "weight_scale" and isinstance(
self.model.quant_config, ct.CompressedTensorsConfig
):
from compressed_tensors import CompressionFormat
is_ct_int8_or_fp8_weight_scale = self.model.quant_config.quant_format in [
CompressionFormat.int_quantized.value,
CompressionFormat.float_quantized.value,
] and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]
if is_weight or is_nvfp4_weight_scale or is_ct_int8_or_fp8_weight_scale:
if "wk" in modules or "k_proj" in modules: if "wk" in modules or "k_proj" in modules:
loaded_weight = permute( loaded_weight = permute(
loaded_weight, loaded_weight,
self.config.num_key_value_heads, self.config.num_key_value_heads,
is_nvfp4_weight_scale, is_nvfp4_weight_scale,
is_ct_int8_or_fp8_weight_scale,
) )
elif "wq" in modules or "q_proj" in modules: elif "wq" in modules or "q_proj" in modules:
loaded_weight = permute( loaded_weight = permute(
loaded_weight, loaded_weight,
self.config.num_attention_heads, self.config.num_attention_heads,
is_nvfp4_weight_scale, is_nvfp4_weight_scale,
is_ct_int8_or_fp8_weight_scale,
) )
return name, loaded_weight return name, loaded_weight
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment