Unverified Commit da971ec7 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Model] Add FP8 kv cache for Qwen2 (#5656)

parent 3eea7488
......@@ -46,6 +46,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.utils import print_warning_once
class Qwen2MLP(nn.Module):
......@@ -375,6 +376,19 @@ class Qwen2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
f"Found kv scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_kv_scale_name}). kv-scale is "
"not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment