"docs/vscode:/vscode.git/clone" did not exist on "e858bfe05167a3bbb064e283da5a1a7709dee24e"
Unverified commit 0e22cd61, authored by Lucas Wilkinson, committed by GitHub
Browse files

Revert "[Llama4,Quantization] Simplify and generalize logic for Q/K...

Revert "[Llama4,Quantization] Simplify and generalize logic for Q/K permutations in quantized self-attn layers" (#34997)
parent ea5f903f
...@@ -44,6 +44,9 @@ from vllm.model_executor.layers.linear import ( ...@@ -44,6 +44,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors import (
compressed_tensors as ct,
)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
...@@ -828,38 +831,74 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts): ...@@ -828,38 +831,74 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
name: str, name: str,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
) -> tuple[str, torch.Tensor]: ) -> tuple[str, torch.Tensor]:
modules = name.split(".") # Helper function to permute the weight's channels
# Permute Q/K weights and corresponding scales for rotary embedding. def permute(
# This pathway is validated against modelopt and compressed-tensors ckpts, w: torch.Tensor,
# and for per-tensor, per-group (e.g. GPTQ), and per-channel quant schemes. n_heads: int,
# Note: permutations are not feasible only for per-block (e.g. DeepSeek 128x128) is_nvfp4_weight_scale: bool,
# For per-block quantization, consider not quantizing q/k_proj. is_ct_int8_or_fp8_weight_scale: bool,
is_weight = modules[-1] in ("weight", "weight_packed") ):
is_weight_scale = ( # Calculate the expected shape of the weight.
modules[-1] == "weight_scale" # Do not rely on w's shape, as it may be in another layout.
and loaded_weight.numel() > 1 # no need to permute per-tensor scales attn_in = self.config.head_dim * n_heads
) attn_out = (
is_k_proj = "wk" in modules or "k_proj" in modules self.config.hidden_size
is_q_proj = "wq" in modules or "q_proj" in modules if not is_ct_int8_or_fp8_weight_scale
else w.shape[-1]
if (is_weight or is_weight_scale) and (is_k_proj or is_q_proj): )
original_ndim = loaded_weight.ndim
if original_ndim == 1: # If the weight is FP4 packed as uint8, we need to divide attn_out
loaded_weight = loaded_weight.unsqueeze(-1) # by 2.
if w.dtype == torch.uint8 and w.shape[1] * 2 == attn_out:
f_out, f_in = loaded_weight.shape attn_out = attn_out // 2
n_heads = (
self.config.num_key_value_heads # If the weight is a weight scale, we need to divide attn_out by
if is_k_proj # block size, which is currently 16.
else self.config.num_attention_heads elif (
) w.dtype == torch.float8_e4m3fn
loaded_weight = ( and is_nvfp4_weight_scale
loaded_weight.view(n_heads, f_out // n_heads // 2, 2, f_in) and w.shape[1] * 16 == attn_out
):
attn_out = attn_out // 16
return (
w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
.transpose(1, 2) .transpose(1, 2)
.reshape(f_out, f_in) .reshape(attn_in, attn_out)
) )
if original_ndim == 1: modules = name.split(".")
loaded_weight = loaded_weight.squeeze(-1)
# Permute Q/K weights and weight block scales for rotary embedding
is_weight = modules[-1] == "weight"
is_nvfp4_weight_scale = (
modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
)
is_ct_int8_or_fp8_weight_scale = False
if modules[-1] == "weight_scale" and isinstance(
self.model.quant_config, ct.CompressedTensorsConfig
):
from compressed_tensors import CompressionFormat
is_ct_int8_or_fp8_weight_scale = self.model.quant_config.quant_format in [
CompressionFormat.int_quantized.value,
CompressionFormat.float_quantized.value,
] and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]
if is_weight or is_nvfp4_weight_scale or is_ct_int8_or_fp8_weight_scale:
if "wk" in modules or "k_proj" in modules:
loaded_weight = permute(
loaded_weight,
self.config.num_key_value_heads,
is_nvfp4_weight_scale,
is_ct_int8_or_fp8_weight_scale,
)
elif "wq" in modules or "q_proj" in modules:
loaded_weight = permute(
loaded_weight,
self.config.num_attention_heads,
is_nvfp4_weight_scale,
is_ct_int8_or_fp8_weight_scale,
)
return name, loaded_weight return name, loaded_weight
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment