Unverified Commit a3a73ab0 authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)

The 2nd PR for #4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
parent 8674f988
...@@ -89,7 +89,8 @@ class GPTNeoXAttention(nn.Module): ...@@ -89,7 +89,8 @@ class GPTNeoXAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_size, self.head_size,
scaling, scaling,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -117,7 +117,8 @@ class InternLM2Attention(nn.Module): ...@@ -117,7 +117,8 @@ class InternLM2Attention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -105,13 +105,12 @@ class JAISAttention(nn.Module): ...@@ -105,13 +105,12 @@ class JAISAttention(nn.Module):
head_end = (tp_rank + 1) * self.num_heads head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(total_num_heads) alibi_slopes = _get_alibi_slopes(total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end] alibi_slopes = alibi_slopes[head_start:head_end]
self.attn = Attention( self.attn = Attention(self.num_heads,
self.num_heads,
self.head_dim, self.head_dim,
scale=self.scale, scale=self.scale,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
cache_config=cache_config, cache_config=cache_config,
) quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader) default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.utils import is_hip from vllm.utils import is_hip, print_warning_once
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
...@@ -119,15 +119,6 @@ class LlamaAttention(nn.Module): ...@@ -119,15 +119,6 @@ class LlamaAttention(nn.Module):
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
# This will be overwritten by model initialization if we are using it.
# N.B. currently we only support per tensor scalar scaling factors
# & only applicable to ROCm (AMD GPU).
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
self.kv_scale = 1.0
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
self.head_dim, self.head_dim,
...@@ -155,7 +146,8 @@ class LlamaAttention(nn.Module): ...@@ -155,7 +146,8 @@ class LlamaAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window, sliding_window=sliding_window,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
...@@ -167,8 +159,7 @@ class LlamaAttention(nn.Module): ...@@ -167,8 +159,7 @@ class LlamaAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata, attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
self.kv_scale)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -421,6 +412,19 @@ class LlamaForCausalLM(nn.Module): ...@@ -421,6 +412,19 @@ class LlamaForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
f"Found kv scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_kv_scale_name}). kv-scale is "
"not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
...@@ -445,7 +449,7 @@ class LlamaForCausalLM(nn.Module): ...@@ -445,7 +449,7 @@ class LlamaForCausalLM(nn.Module):
# scaling_factor = tensor_amax / FPtype_max # scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2 scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"): if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.kv_scale = scaling_factor layer_self_attn.attn._kv_scale = scaling_factor
else: else:
raise RuntimeError("Self attention has no KV cache scaling " raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!") "factor attribute!")
...@@ -236,7 +236,8 @@ class MiniCPMAttention(nn.Module): ...@@ -236,7 +236,8 @@ class MiniCPMAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -308,14 +308,13 @@ class MixtralAttention(nn.Module): ...@@ -308,14 +308,13 @@ class MixtralAttention(nn.Module):
base=int(self.rope_theta), base=int(self.rope_theta),
is_neox_style=True, is_neox_style=True,
) )
self.attn = Attention( self.attn = Attention(self.num_heads,
self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
cache_config=cache_config, cache_config=cache_config,
) quant_config=quant_config)
def forward( def forward(
self, self,
...@@ -581,6 +580,20 @@ class MixtralForCausalLM(nn.Module): ...@@ -581,6 +580,20 @@ class MixtralForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -213,14 +213,13 @@ class MixtralAttention(nn.Module): ...@@ -213,14 +213,13 @@ class MixtralAttention(nn.Module):
base=int(self.rope_theta), base=int(self.rope_theta),
is_neox_style=True, is_neox_style=True,
) )
self.attn = Attention( self.attn = Attention(self.num_heads,
self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
cache_config=cache_config, cache_config=cache_config,
) quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -110,7 +110,8 @@ class MPTAttention(nn.Module): ...@@ -110,7 +110,8 @@ class MPTAttention(nn.Module):
scaling, scaling,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -96,7 +96,8 @@ class OlmoAttention(nn.Module): ...@@ -96,7 +96,8 @@ class OlmoAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.scaling, scale=self.scaling,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
# Attention output projection. # Attention output projection.
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
......
...@@ -91,7 +91,8 @@ class OPTAttention(nn.Module): ...@@ -91,7 +91,8 @@ class OPTAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.scaling, scale=self.scaling,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -121,7 +121,8 @@ class OrionAttention(nn.Module): ...@@ -121,7 +121,8 @@ class OrionAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -110,7 +110,8 @@ class PhiAttention(nn.Module): ...@@ -110,7 +110,8 @@ class PhiAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_size, self.head_size,
scaling, scaling,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -106,7 +106,8 @@ class QWenAttention(nn.Module): ...@@ -106,7 +106,8 @@ class QWenAttention(nn.Module):
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -141,7 +141,8 @@ class Qwen2Attention(nn.Module): ...@@ -141,7 +141,8 @@ class Qwen2Attention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -241,7 +241,8 @@ class Qwen2MoeAttention(nn.Module): ...@@ -241,7 +241,8 @@ class Qwen2MoeAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -127,7 +127,8 @@ class StablelmAttention(nn.Module): ...@@ -127,7 +127,8 @@ class StablelmAttention(nn.Module):
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_key_value_heads, num_kv_heads=self.num_key_value_heads,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -97,14 +97,13 @@ class Starcoder2Attention(nn.Module): ...@@ -97,14 +97,13 @@ class Starcoder2Attention(nn.Module):
base=int(self.rope_theta), base=int(self.rope_theta),
is_neox_style=True, is_neox_style=True,
) )
self.attn = Attention( self.attn = Attention(self.num_heads,
self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
cache_config=cache_config, cache_config=cache_config,
) quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -135,7 +135,8 @@ class XverseAttention(nn.Module): ...@@ -135,7 +135,8 @@ class XverseAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window, sliding_window=sliding_window,
cache_config=cache_config) cache_config=cache_config,
quant_config=quant_config)
def forward( def forward(
self, self,
......
...@@ -31,6 +31,8 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -31,6 +31,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
"float": torch.float, "float": torch.float,
"fp8": torch.uint8, "fp8": torch.uint8,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
} }
......
import time import time
import warnings
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
import numpy as np import numpy as np
...@@ -168,11 +169,21 @@ class ModelRunner: ...@@ -168,11 +169,21 @@ class ModelRunner:
self.model = self.lora_manager.create_lora_manager(self.model) self.model = self.lora_manager.create_lora_manager(self.model)
if self.kv_cache_dtype == "fp8" and is_hip(): if self.kv_cache_dtype == "fp8" and is_hip():
# Currently scaled KV cache is only enabled on ROCm # Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.
if self.model_config.quantization_param_path is not None: if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)): if callable(getattr(self.model, "load_kv_cache_scales", None)):
warnings.warn(
"Loading kv cache scaling factor from JSON is "
"deprecated and will be removed. Please include "
"kv cache scaling factors in the model checkpoint.",
FutureWarning,
stacklevel=2)
self.model.load_kv_cache_scales( self.model.load_kv_cache_scales(
self.model_config.quantization_param_path) self.model_config.quantization_param_path)
logger.info("Loaded KV cache scaling factors from %s",
self.model_config.quantization_param_path)
else: else:
raise RuntimeError( raise RuntimeError(
"Using FP8 KV cache and scaling factors provided but " "Using FP8 KV cache and scaling factors provided but "
...@@ -183,10 +194,6 @@ class ModelRunner: ...@@ -183,10 +194,6 @@ class ModelRunner:
"Using FP8 KV cache but no scaling factors " "Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. " "provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!") "This may lead to less accurate results!")
elif self.model_config.quantization_param_path is not None:
logger.warning("KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.")
def save_sharded_state( def save_sharded_state(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment