Commit a3a73ab0 (unverified), authored by Cody Yu, committed by GitHub
Browse files

[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)

The 2nd PR for #4532.

This PR adds support for loading FP8 kv-cache scaling factors from an FP8 checkpoint (one that contains `.kv_scale` parameters).
parent 8674f988
......@@ -89,7 +89,8 @@ class GPTNeoXAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -117,7 +117,8 @@ class InternLM2Attention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -105,13 +105,12 @@ class JAISAttention(nn.Module):
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end]
self.attn = Attention(
self.num_heads,
self.head_dim,
scale=self.scale,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scale,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip
from vllm.utils import is_hip, print_warning_once
class LlamaMLP(nn.Module):
......@@ -119,15 +119,6 @@ class LlamaAttention(nn.Module):
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# This will be overwritten by model initialization if we are using it.
# N.B. currently we only support per tensor scalar scaling factors
# & only applicable to ROCm (AMD GPU).
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
self.kv_scale = 1.0
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
......@@ -155,7 +146,8 @@ class LlamaAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......@@ -167,8 +159,7 @@ class LlamaAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
self.kv_scale)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
......@@ -421,6 +412,19 @@ class LlamaForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
f"Found kv scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_kv_scale_name}). kv-scale is "
"not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......@@ -445,7 +449,7 @@ class LlamaForCausalLM(nn.Module):
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.kv_scale = scaling_factor
layer_self_attn.attn._kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
......@@ -236,7 +236,8 @@ class MiniCPMAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -308,14 +308,13 @@ class MixtralAttention(nn.Module):
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......@@ -581,6 +580,20 @@ class MixtralForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -213,14 +213,13 @@ class MixtralAttention(nn.Module):
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -110,7 +110,8 @@ class MPTAttention(nn.Module):
scaling,
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -96,7 +96,8 @@ class OlmoAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
# Attention output projection.
self.o_proj = RowParallelLinear(
......
......@@ -91,7 +91,8 @@ class OPTAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -121,7 +121,8 @@ class OrionAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -110,7 +110,8 @@ class PhiAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -106,7 +106,8 @@ class QWenAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -141,7 +141,8 @@ class Qwen2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -241,7 +241,8 @@ class Qwen2MoeAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -127,7 +127,8 @@ class StablelmAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_key_value_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -97,14 +97,13 @@ class Starcoder2Attention(nn.Module):
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -135,7 +135,8 @@ class XverseAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -31,6 +31,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.uint8,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
}
......
import time
import warnings
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
import numpy as np
......@@ -168,11 +169,21 @@ class ModelRunner:
self.model = self.lora_manager.create_lora_manager(self.model)
if self.kv_cache_dtype == "fp8" and is_hip():
# Currently scaled KV cache is only enabled on ROCm
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
warnings.warn(
"Loading kv cache scaling factor from JSON is "
"deprecated and will be removed. Please include "
"kv cache scaling factors in the model checkpoint.",
FutureWarning,
stacklevel=2)
self.model.load_kv_cache_scales(
self.model_config.quantization_param_path)
logger.info("Loaded KV cache scaling factors from %s",
self.model_config.quantization_param_path)
else:
raise RuntimeError(
"Using FP8 KV cache and scaling factors provided but "
......@@ -183,10 +194,6 @@ class ModelRunner:
"Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
elif self.model_config.quantization_param_path is not None:
logger.warning("KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.")
def save_sharded_state(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment