Unverified Commit a3a73ab0 authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)

The 2nd PR for #4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
parent 8674f988
......@@ -153,15 +153,13 @@ if __name__ == '__main__':
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
"--kv-cache-dtype",
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8'],
default='auto',
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
......
......@@ -323,15 +323,13 @@ if __name__ == "__main__":
action="store_true",
help="enforce eager execution")
parser.add_argument(
"--kv-cache-dtype",
'--kv-cache-dtype',
type=str,
choices=["auto", "fp8"],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
......
......@@ -183,13 +183,11 @@ if __name__ == '__main__':
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8"],
choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
args = parser.parse_args()
print(args)
......
......@@ -16,22 +16,35 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
MAX_MODEL_LEN = 1024
MODELS = [
"nm-testing/Meta-Llama-3-8B-Instruct-FP8",
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
"meta-llama/Meta-Llama-3-8B-Instruct",
]
EXPECTED_STRS_MAP = {
"nm-testing/Meta-Llama-3-8B-Instruct-FP8": [
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": {
"auto": [
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred to a',
'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o',
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no'
],
"meta-llama/Meta-Llama-3-8B-Instruct": [
"fp8": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system made up of several basic components that work together to enable it to',
'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk'
]
},
"meta-llama/Meta-Llama-3-8B-Instruct": {
"auto": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
......@@ -41,6 +54,17 @@ EXPECTED_STRS_MAP = {
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
],
"fp8": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu'
]
},
}
capability = torch.cuda.get_device_capability()
......@@ -52,14 +76,14 @@ fp8_not_supported = (capability <
@pytest.mark.skipif(fp8_not_supported,
reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(
example_prompts,
model_name,
) -> None:
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
model = LLM(model=model_name,
max_model_len=MAX_MODEL_LEN,
trust_remote_code=True,
enforce_eager=True,
quantization="fp8")
quantization="fp8",
kv_cache_dtype=kv_cache_dtype)
tokenizer = AutoTokenizer.from_pretrained(model_name)
formatted_prompts = [
......@@ -81,8 +105,8 @@ def test_models(
generations.append(outputs[0].outputs[0].text)
del model
print(generations)
expected_strs = EXPECTED_STRS_MAP[model_name]
print(model_name, kv_cache_dtype, generations)
expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
for i in range(len(example_prompts)):
generated_str = generations[i]
expected_str = expected_strs[i]
......
......@@ -7,6 +7,8 @@ import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class Attention(nn.Module):
......@@ -30,6 +32,7 @@ class Attention(nn.Module):
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
if cache_config is not None:
......@@ -40,6 +43,27 @@ class Attention(nn.Module):
block_size = 16
if num_kv_heads is None:
num_kv_heads = num_heads
# The default kv_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized kv_scale to be loaded along
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self._kv_scale = 1.0
quant_method = quant_config.get_quant_method(
self) if quant_config else None
if quant_method is not None:
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# When FP8 quantization is enabled, we make a parameter
# "kv_scale" so that it can be loaded from FP8 checkpoint.
# The kv_scale will then be converted back
# to self._kv_scale in a native float32 value after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
......@@ -57,10 +81,9 @@ class Attention(nn.Module):
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
kv_scale)
self._kv_scale)
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
......
......@@ -355,14 +355,12 @@ class CacheConfig:
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype == "fp8":
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"But it may cause slight accuracy drop without scaling "
"factors. FP8_E5M2 (without scaling) is only supported on "
"cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
"is instead supported for common inference criteria.")
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
......
......@@ -191,12 +191,11 @@ class EngineArgs:
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8'],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. FP8_E5M2 (without scaling) is only supported on cuda '
'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria.')
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,
......
......@@ -8,8 +8,9 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once
ACTIVATION_SCHEMES = ["static", "dynamic"]
......@@ -58,9 +59,13 @@ class Fp8Config(QuantizationConfig):
activation_scheme=activation_scheme)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
return Fp8LinearMethod(self)
if isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
......@@ -251,6 +256,44 @@ class Fp8LinearMethod(LinearMethodBase):
return torch.narrow(output, 0, 0, x.shape[0])
class Fp8KVCacheMethod(QuantizeMethodBase):
"""Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module):
"""Create "weight" (aka kv_scale) for an attention layer.
Args:
layer: The layer that is using the QuantizeMethodBase factory.
"""
# Initialize the KV cache scale to 1.0 as the default value.
# If the kv_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
def process_weights_after_loading(self, layer: Module) -> None:
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto":
kv_scale = layer.kv_scale.to("cpu").tolist()
if not isinstance(kv_scale, float):
raise ValueError("Only support per-tensor scaling factor "
"for fp8 KV cache")
layer._kv_scale = kv_scale
if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
print_warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint.")
del layer.kv_scale
def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
......
......@@ -268,7 +268,8 @@ class ArcticAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -154,7 +154,8 @@ class BaiChuanAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes)
alibi_slopes=alibi_slopes,
quant_config=quant_config)
else:
self.rotary_emb = get_rope(
self.head_dim,
......@@ -166,7 +167,8 @@ class BaiChuanAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -111,7 +111,8 @@ class BloomAttention(nn.Module):
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -86,13 +86,12 @@ class GLMAttention(nn.Module):
base=10000 * rope_ratio,
is_neox_style=False,
)
self.attn = Attention(
self.num_heads,
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
quant_config=quant_config)
def forward(
self,
......
......@@ -177,13 +177,12 @@ class CohereAttention(nn.Module):
rope_scaling=self.rope_scaling,
is_neox_style=False,
)
self.attn = Attention(
self.num_heads,
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
quant_config=quant_config)
if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads,
self.head_dim),
......
......@@ -218,13 +218,12 @@ class DbrxAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.attn = Attention(
self.num_heads,
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
quant_config=quant_config)
def forward(
self,
......
......@@ -232,7 +232,8 @@ class DeepseekAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -153,7 +153,8 @@ class FalconAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
quant_config=quant_config)
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
......@@ -165,13 +166,15 @@ class FalconAttention(nn.Module):
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
alibi_slopes=alibi_slopes)
alibi_slopes=alibi_slopes,
quant_config=quant_config)
else:
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -157,7 +157,8 @@ class GemmaAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -75,7 +75,8 @@ class GPT2Attention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scale,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -88,7 +88,8 @@ class GPTBigCodeAttention(nn.Module):
self.head_dim,
scale=self.scale,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
......@@ -88,7 +88,8 @@ class GPTJAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment