Unverified Commit 978aed53 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Kernel][Attention] Separate `Attention.kv_scale` into `k_scale` and `v_scale` (#6081)

parent 160e1d8c
...@@ -131,7 +131,8 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -131,7 +131,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
attn_metadata: PallasMetadata, attn_metadata: PallasMetadata,
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with Pallas attention. """Forward pass with Pallas attention.
...@@ -146,7 +147,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -146,7 +147,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns: Returns:
shape = [batch_size, seq_len, num_heads * head_size] shape = [batch_size, seq_len, num_heads * head_size]
""" """
assert kv_scale == 1.0 assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
......
...@@ -296,7 +296,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -296,7 +296,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata, attn_metadata: ROCmFlashAttentionMetadata,
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
...@@ -336,7 +337,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -336,7 +337,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache, value_cache,
attn_metadata.slot_mapping, attn_metadata.slot_mapping,
self.kv_cache_dtype, self.kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
num_prefill_tokens = attn_metadata.num_prefill_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens
...@@ -456,7 +458,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -456,7 +458,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
kv_scale, k_scale,
v_scale,
) )
# Reshape the output tensor. # Reshape the output tensor.
......
...@@ -144,7 +144,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -144,7 +144,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
...@@ -158,7 +159,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -158,7 +159,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert kv_scale == 1.0 assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
...@@ -176,7 +177,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -176,7 +177,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
PagedAttention.write_to_paged_cache(key, value, key_cache, PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache, value_cache,
attn_metadata.slot_mapping, attn_metadata.slot_mapping,
self.kv_cache_dtype, kv_scale) self.kv_cache_dtype, k_scale,
v_scale)
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None assert attn_metadata.seq_lens is not None
...@@ -239,7 +241,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -239,7 +241,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
kv_scale, k_scale,
v_scale,
) )
# Reshape the output tensor. # Reshape the output tensor.
......
...@@ -427,7 +427,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -427,7 +427,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
value: Optional[torch.Tensor], value: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: "XFormersMetadata", attn_metadata: "XFormersMetadata",
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
...@@ -531,7 +532,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -531,7 +532,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
value_cache, value_cache,
updated_slot_mapping, updated_slot_mapping,
self.kv_cache_dtype, self.kv_cache_dtype,
kv_scale) k_scale, v_scale)
if attn_type != AttentionType.ENCODER: if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill. # Decoder self-attention supports chunked prefill.
...@@ -620,7 +621,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -620,7 +621,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
kv_scale, k_scale,
v_scale,
) )
# Reshape the output tensor. # Reshape the output tensor.
......
...@@ -47,13 +47,14 @@ class Attention(nn.Module): ...@@ -47,13 +47,14 @@ class Attention(nn.Module):
if num_kv_heads is None: if num_kv_heads is None:
num_kv_heads = num_heads num_kv_heads = num_heads
# The default kv_scale is set to 1.0. This is ignored # The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with # when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized kv_scale to be loaded along # expect the pre-quantized k/v_scale to be loaded along
# with the model weights. # with the model weights.
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self._kv_scale = 1.0 self._k_scale = 1.0
self._v_scale = 1.0
quant_method = quant_config.get_quant_method( quant_method = quant_config.get_quant_method(
self) if quant_config else None self) if quant_config else None
if quant_method is not None: if quant_method is not None:
...@@ -66,8 +67,8 @@ class Attention(nn.Module): ...@@ -66,8 +67,8 @@ class Attention(nn.Module):
"fp8 checkpoints.") "fp8 checkpoints.")
# When FP8 quantization is enabled, we make a parameter # When FP8 quantization is enabled, we make a parameter
# "kv_scale" so that it can be loaded from FP8 checkpoint. # "kv_scale" so that it can be loaded from FP8 checkpoint.
# The kv_scale will then be converted back to self._kv_scale # The k/v_scale will then be converted back to
# in a native float32 value after weight loading. # self._kv_scale in a native float32 value after weight loading
self.quant_method = quant_method self.quant_method = quant_method
self.quant_method.create_weights(self) self.quant_method.create_weights(self)
...@@ -98,7 +99,8 @@ class Attention(nn.Module): ...@@ -98,7 +99,8 @@ class Attention(nn.Module):
value, value,
kv_cache, kv_cache,
attn_metadata, attn_metadata,
self._kv_scale, self._k_scale,
self._v_scale,
attn_type=attn_type) attn_type=attn_type)
def extra_repr(self) -> str: def extra_repr(self) -> str:
......
...@@ -45,7 +45,8 @@ class PagedAttention: ...@@ -45,7 +45,8 @@ class PagedAttention:
value_cache: torch.Tensor, value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, k_scale: float,
v_scale: float,
*args, *args,
) -> None: ) -> None:
ipex_modules.PagedAttention.reshape_and_cache( ipex_modules.PagedAttention.reshape_and_cache(
...@@ -64,7 +65,8 @@ class PagedAttention: ...@@ -64,7 +65,8 @@ class PagedAttention:
num_kv_heads: int, num_kv_heads: int,
scale: float, scale: float,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_scale: float, k_scale: float,
v_scale: float,
*args, *args,
) -> torch.Tensor: ) -> torch.Tensor:
output = torch.empty_like(query) output = torch.empty_like(query)
......
...@@ -66,7 +66,8 @@ class PagedAttention: ...@@ -66,7 +66,8 @@ class PagedAttention:
value_cache: torch.Tensor, value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, k_scale: float,
v_scale: float,
) -> None: ) -> None:
ops.reshape_and_cache( ops.reshape_and_cache(
key, key,
...@@ -75,7 +76,8 @@ class PagedAttention: ...@@ -75,7 +76,8 @@ class PagedAttention:
value_cache, value_cache,
slot_mapping.flatten(), slot_mapping.flatten(),
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
@staticmethod @staticmethod
...@@ -90,7 +92,8 @@ class PagedAttention: ...@@ -90,7 +92,8 @@ class PagedAttention:
num_kv_heads: int, num_kv_heads: int,
scale: float, scale: float,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_scale: float, k_scale: float,
v_scale: float,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -135,7 +138,8 @@ class PagedAttention: ...@@ -135,7 +138,8 @@ class PagedAttention:
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
tp_rank, tp_rank,
blocksparse_local_blocks, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_vert_stride,
...@@ -172,7 +176,8 @@ class PagedAttention: ...@@ -172,7 +176,8 @@ class PagedAttention:
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
tp_rank, tp_rank,
blocksparse_local_blocks, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_vert_stride,
......
...@@ -196,6 +196,15 @@ class ReplicatedLinear(LinearBase): ...@@ -196,6 +196,15 @@ class ReplicatedLinear(LinearBase):
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# If the weight on disk does not have a shape, give it one
# (such scales for AutoFp8).
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None assert self.quant_method is not None
......
...@@ -407,31 +407,56 @@ class Fp8KVCacheMethod(QuantizeMethodBase): ...@@ -407,31 +407,56 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module): def create_weights(self, layer: torch.nn.Module):
"""Create "weight" (aka kv_scale) for an attention layer. """Create "weight" (aka k_scale and v_scale) for an attention layer.
Args: Args:
layer: The layer that is using the QuantizeMethodBase factory. layer: The layer that is using the QuantizeMethodBase factory.
""" """
# Initialize the KV cache scale to 1.0 as the default value. # Initialize the KV cache scales to -1.0, which is an invalid value.
# If the kv_scale appears in the checkpoint, it will be # If the k/v_scale appears in the checkpoint, it will be
# overwritten when loading weights. # overwritten when loading weights.
layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False) layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
def apply(self, layer: torch.nn.Module) -> torch.Tensor: def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.") raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0 # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint. # regardless whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto": if layer.kv_cache_dtype != "auto":
kv_scale = layer.kv_scale.to("cpu").tolist() if layer.k_scale > 0.0 and layer.v_scale > 0.0:
if not isinstance(kv_scale, float): # We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist()
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
k_scale = Parameter(torch.tensor(1.0), requires_grad=False)
v_scale = Parameter(torch.tensor(1.0), requires_grad=False)
else:
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
# k_scale to v_scale here
assert layer.k_scale > 0.0
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()
if not isinstance(k_scale, float) or not isinstance(
v_scale, float):
raise ValueError("Only support per-tensor scaling factor " raise ValueError("Only support per-tensor scaling factor "
"for fp8 KV cache") "for fp8 KV cache")
layer._kv_scale = kv_scale
if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: # These are used in the final Attention.forward()
layer._k_scale = k_scale
layer._v_scale = v_scale
if (layer._k_scale == 1.0 and layer._v_scale == 1.0
and "e5m2" not in layer.kv_cache_dtype):
print_warning_once( print_warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may " "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
"cause accuracy issues. Please make sure kv-cache scaling " "may cause accuracy issues. Please make sure k/v_scale "
"factor is available in the fp8 checkpoint.") "scaling factors are available in the fp8 checkpoint.")
del layer.kv_scale
del layer.k_scale
del layer.v_scale
...@@ -22,6 +22,7 @@ from vllm.logger import init_logger ...@@ -22,6 +22,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig, from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config) get_quantization_config)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema from vllm.model_executor.layers.quantization.schema import QuantParamSchema
from vllm.utils import print_warning_once
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -431,11 +432,6 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: ...@@ -431,11 +432,6 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
def default_weight_loader(param: torch.Tensor, def default_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None: loaded_weight: torch.Tensor) -> None:
"""Default weight loader.""" """Default weight loader."""
# If the weight on disk does not have a shape, give it one
# (such scales for AutoFp8).
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param.size() == loaded_weight.size() assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight) param.data.copy_(loaded_weight)
...@@ -462,3 +458,55 @@ def initialize_dummy_weights( ...@@ -462,3 +458,55 @@ def initialize_dummy_weights(
param.data.copy_(tmp_param) param.data.copy_(tmp_param)
else: else:
param.uniform_(low, high) param.uniform_(low, high)
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
"""Remap the name of FP8 k/v_scale parameters.
This function handles the remapping of FP8 k/v_scale parameter names.
It detects if the given name ends with a suffix and attempts to remap
it to the expected name format in the model. If the remapped name is not
found in the params_dict, a warning is printed and None is returned.
Args:
name (str): The original loaded checkpoint parameter name.
params_dict (dict): Dictionary containing the model's named parameters.
Returns:
str: The remapped parameter name if successful, or the original name
if no remapping is needed.
None: If the remapped name is not found in params_dict.
"""
if name.endswith(".kv_scale"):
print_warning_once(
"DEPRECATED. Found kv_scale in the checkpoint. "
"This format is deprecated in favor of separate k_scale and "
"v_scale tensors and will be removed in a future release. "
"Functionally, we will remap kv_scale to k_scale and duplicate "
"k_scale to v_scale")
# NOTE: we remap the deprecated kv_scale to k_scale
remapped_name = name.replace(".kv_scale", ".attn.k_scale")
if remapped_name not in params_dict:
print_warning_once(
f"Found kv_scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_name}). kv_scale is "
"not loaded.")
return None
return remapped_name
possible_scale_names = [".k_scale", ".v_scale"]
for scale_name in possible_scale_names:
if name.endswith(scale_name):
remapped_name = name.replace(scale_name, f".attn{scale_name}")
if remapped_name not in params_dict:
print_warning_once(
f"Found {scale_name} in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_name}). {scale_name} is "
"not loaded.")
return None
return remapped_name
# If there were no matches, return the untouched param name
return name
...@@ -44,10 +44,10 @@ from vllm.model_executor.layers.sampler import Sampler ...@@ -44,10 +44,10 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import is_hip, print_warning_once from vllm.utils import is_hip
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers from .utils import is_pp_missing_parameter, make_layers
...@@ -460,18 +460,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -460,18 +460,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Remapping the name of FP8 kv-scale. # Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"): name = maybe_remap_kv_scale_name(name, params_dict)
remapped_kv_scale_name = name.replace( if name is None:
".kv_scale", ".attn.kv_scale") continue
if remapped_kv_scale_name not in params_dict:
print_warning_once(
f"Found kv scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_kv_scale_name}). kv-scale is "
"not loaded.")
continue
else:
name = remapped_kv_scale_name
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
......
...@@ -42,10 +42,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -42,10 +42,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -415,19 +415,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -415,19 +415,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Remapping the name of FP8 kv-scale. # Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"): name = maybe_remap_kv_scale_name(name, params_dict)
remapped_kv_scale_name = name.replace( if name is None:
".kv_scale", ".attn.kv_scale") continue
if remapped_kv_scale_name not in params_dict:
print_warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -43,10 +43,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -43,10 +43,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -382,18 +382,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -382,18 +382,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Remapping the name of FP8 kv-scale. # Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"): name = maybe_remap_kv_scale_name(name, params_dict)
remapped_kv_scale_name = name.replace( if name is None:
".kv_scale", ".attn.kv_scale") continue
if remapped_kv_scale_name not in params_dict:
print_warning_once(
f"Found kv scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_kv_scale_name}). kv-scale is "
"not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment