Commit 2ff767b5 (unverified), authored by Adrian Abeyta, committed by GitHub
Browse files

Enable scaled FP8 (e4m3fn) KV cache on ROCm (AMD GPU) (#3290)


Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: HaiShaw <hixiao@gmail.com>
Co-authored-by: AdrianAbeyta <Adrian.Abeyta@amd.com>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: root <root@gt-pla-u18-08.pla.dcgpu>
Co-authored-by: mawong-amd <156021403+mawong-amd@users.noreply.github.com>
Co-authored-by: ttbachyinsda <ttbachyinsda@outlook.com>
Co-authored-by: guofangze <guofangze@kuaishou.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 3dcb3e8b
import argparse
import glob
import json
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import numpy as np
import torch
from safetensors.torch import safe_open
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
# Adapted from vllm/model_executor/weight_utils.py
# The main differences are that we add the NPZ format and simplify
# its functionality drastically for our purposes (e.g. we assume that
# the quantized model exists locally and there is no need to download it)
def _prepare_hf_weights(
    quantized_model_dir: str,
    load_format: str = "auto",
    fall_back_to_pt: bool = True,
) -> Tuple[List[str], bool]:
    """Locate the tensor files of a locally stored quantized model.

    Args:
        quantized_model_dir: Directory holding the quantized model files.
        load_format: One of "auto", "safetensors", "pt" or "npz"; selects
            which file patterns are searched.
        fall_back_to_pt: If True, "*.pt" is searched as a last resort since
            some quantized models use .pt files for storing the weights.

    Returns:
        A tuple (hf_weights_files, use_safetensors). use_safetensors is True
        only when the returned files matched the "*.safetensors" pattern.

    Raises:
        FileNotFoundError: quantized_model_dir is not an existing directory.
        ValueError: load_format is not one of the supported values.
        RuntimeError: no usable tensor file was found.
    """
    if not os.path.isdir(quantized_model_dir):
        raise FileNotFoundError(
            f"The quantized model directory `{quantized_model_dir}` "
            "does not exist.")

    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npz":
        allow_patterns = ["*.npz"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")
    # Some quantized models use .pt files for storing the weights.
    if fall_back_to_pt and "*.pt" not in allow_patterns:
        allow_patterns.append("*.pt")

    # Patterns are tried in priority order; the first one that matches wins.
    hf_weights_files: List[str] = []
    use_safetensors = False
    for pattern in allow_patterns:
        hf_weights_files = glob.glob(
            os.path.join(quantized_model_dir, pattern))
        if hf_weights_files:
            # Derive the flag from what was actually found, so that a
            # fallback to .pt files is never reported as safetensors.
            use_safetensors = pattern == "*.safetensors"
            break

    if not use_safetensors:
        # Exclude files that are not needed for inference.
        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
        blacklist = (
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        )
        hf_weights_files = [
            f for f in hf_weights_files if not f.endswith(blacklist)
        ]

    if not hf_weights_files:
        raise RuntimeError(
            f"Cannot find any model weights with `{quantized_model_dir}`")
    return hf_weights_files, use_safetensors
# Adapted from vllm/model_executor/weight_utils.py
def _hf_tensorfile_iterator(filename: str, load_format: str,
                            use_safetensors: bool):
    """Yield (name, tensor) pairs stored in a single tensor file.

    The reader is chosen from the arguments: NumPy's archive loader when
    load_format is "npz", safetensors' safe_open when use_safetensors is
    set, and torch.load (mapped to the CPU) in every other case.
    """
    if load_format == "npz":
        assert not use_safetensors
        with np.load(filename) as archive:
            for key in archive.files:
                yield key, torch.from_numpy(archive[key])
        return

    if use_safetensors:
        with safe_open(filename, framework="pt") as handle:
            for key in handle.keys():  # NOQA: SIM118
                yield key, handle.get_tensor(key)
        return

    checkpoint = torch.load(filename, map_location="cpu")
    yield from checkpoint.items()
    # Drop the reference to the loaded state once it has been consumed.
    del checkpoint
    torch.cuda.empty_cache()
def _kv_scales_extractor(
        hf_tensor_files: Iterable[str],
        use_safetensors: bool,
        rank_keyword: str = "rank",
        expected_tp_size: Optional[int] = None,
        load_format: Optional[str] = None
) -> Optional[Dict[int, Dict[int, float]]]:
    """
    Given a list of files containing tensor data, attempt to extract KV cache
    scales from these files. Intended as a helper function taking in the output
    from _prepare_hf_weights.
    Args:
    hf_tensor_files   Paths of the tensor files to scan.
    use_safetensors   Whether the files are in the safetensors format.
    rank_keyword      Matches the number immediately after this keyword in the
                      tensor filename to determine the TP rank corresponding
                      to said tensor file
    expected_tp_size  If specified, the TP size of the tensor files is checked
                      against this and an error is raised if they don't match.
    load_format       Format used to read the files. If None, each file is
                      read as "npz" when its name ends in ".npz" and with the
                      default readers otherwise.
    Returns a dictionary mapping TP ranks to their relevant KV cache scales.
    The per-rank scales are themselves represented as a dictionary of layer
    indices to the respective per-layer scale. Returns None (after printing
    a warning) if no KV cache scaling factor is found in any file.
    Raises a RuntimeError if the TP rank of a file cannot be determined, if
    two files share a TP rank, or if a scaling factor is not a scalar.
    """
    for char in rank_keyword:
        assert not char.isdecimal(
        ), f"Rank keyword {rank_keyword} contains a numeric character!"

    # The argument is only required to be iterable, but its length is needed
    # below, so materialize it once.
    hf_tensor_files = list(hf_tensor_files)
    rank_scales_map: Dict[int, Dict[int, float]] = {}
    for tensor_file in hf_tensor_files:
        # Only the file name may carry the rank; searching the whole path
        # would wrongly match the keyword inside a directory name.
        file_name = os.path.basename(tensor_file)
        try:
            rank_idx = file_name.find(rank_keyword)
            if rank_idx != -1:
                start_idx = rank_idx + len(rank_keyword)
                stop_idx = start_idx
                while stop_idx < len(
                        file_name) and file_name[stop_idx].isdecimal():
                    stop_idx += 1
                if stop_idx == start_idx:
                    raise RuntimeError("Did not find rank # in filename.")
                rank = int(file_name[start_idx:stop_idx])
            elif len(hf_tensor_files) == 1:
                # Since there is only one tensor file, we can assume
                # that it's intended for TP rank 0
                rank = 0
            else:
                raise RuntimeError(
                    f"Filename does not contain '{rank_keyword}'.")
        except RuntimeError:
            print("Unable to determine TP rank "
                  f"corresponding to file '{tensor_file}'")
            raise

        if rank in rank_scales_map:
            raise RuntimeError(
                f"Tensor file '{tensor_file}' shares TP rank {rank} "
                "with another tensor file.")
        layer_scales_map: Dict[int, float] = {}
        rank_scales_map[rank] = layer_scales_map

        # Resolve the format locally instead of reading a module global, so
        # the function also works when this module is imported.
        file_format = load_format
        if file_format is None:
            file_format = "npz" if tensor_file.endswith(".npz") else "auto"
        module_delimiter = ":" if file_format == "npz" else "."
        for name, param in _hf_tensorfile_iterator(tensor_file, file_format,
                                                   use_safetensors):
            if "kv_cache_scaling_factor" not in name:
                continue
            # The layer index is the only purely numeric name component.
            nums = [
                int(s) for s in name.split(module_delimiter) if s.isdecimal()
            ]
            assert len(
                nums) == 1, f"Could not determine layer idx for {name}"
            layer_idx = nums[0]
            assert layer_idx not in layer_scales_map, "Duplicate scaling"\
                f" factor corresponding to layer {layer_idx}"
            try:
                layer_scales_map[layer_idx] = param.item()
            except RuntimeError:
                print("This utility supports only per-tensor scalar scales "
                      f"for now. The tensor\n {name} = {param} \nis an "
                      "invalid scale factor.")
                raise

    if not any(rank_scales_map.values()):
        # Note: this is true even if the rank_scales_map is empty
        print("WARNING: No KV cache scale factors found. No output saved.")
        return None

    empirical_tp_world_size = max(rank_scales_map.keys()) + 1
    if expected_tp_size is not None:
        assert expected_tp_size == empirical_tp_world_size, \
            f"User expected TP world size = {expected_tp_size} " \
            "from model but tool is expecting TP world size = " \
            f"{empirical_tp_world_size} from model instead."
    for i in range(empirical_tp_world_size):
        assert i in rank_scales_map, "Expected TP world size = "\
            f"{empirical_tp_world_size} but did not find KV " \
            f"cache scaling factors for TP rank {i}"
    print(f"Found TP world size = {empirical_tp_world_size} "
          "when extracting KV cache scales!")
    return rank_scales_map
def _metadata_extractor(
    quantized_model_dir: str,
    metadata_extract_fns: Dict[str, Callable[[Dict[str, Any]], Any]],
) -> Dict[str, Any]:
    """
    Collect metadata fields from the JSON files of a quantized model.

    Every "*.json" file directly inside quantized_model_dir is parsed. Files
    that are not valid JSON, or whose top-level value is not a dictionary,
    are reported and skipped. For each remaining JSON-dictionary, every
    function in metadata_extract_fns (field name -> extraction function) is
    called with that dictionary as its only argument.

    A KeyError or ValueError raised during extraction means "this field is
    not available in this file" and is ignored, since the field may still
    be found in another file. Any other exception propagates.

    If a field is found in several files, all values must agree; otherwise
    a RuntimeError is raised. Raises FileNotFoundError if the directory
    does not exist.

    Returns a dictionary with exactly the keys of metadata_extract_fns.
    Fields that could not be extracted are set to None after a warning.
    """
    if not os.path.isdir(quantized_model_dir):
        raise FileNotFoundError(
            f"The quantized model directory `{quantized_model_dir}` "
            "does not exist.")

    found: Dict[str, Any] = {}
    json_paths = glob.glob(os.path.join(quantized_model_dir, "*.json"))
    for json_path in json_paths:
        try:
            with open(json_path) as handle:
                json_dict = json.load(handle)
        except json.JSONDecodeError:
            print(f"Could not parse `{json_path}` as a valid metadata file,"
                  " skipping it.")
            continue
        if not isinstance(json_dict, dict):
            print(f"The file `{json_path}` does not correspond to a "
                  "JSON-serialized dictionary, skipping it.")
            continue

        for field, extract_fn in metadata_extract_fns.items():
            try:
                value = extract_fn(json_dict)
                if field not in found:
                    found[field] = value
                elif value != found[field]:
                    raise RuntimeError(
                        "Metadata mismatch! Originally found "
                        f"{field} = {found[field]} but "
                        f"now found {field} = {value} in "
                        f"`{json_path}`")
            except (KeyError, ValueError):
                # 'EFINAE': extract_fn failure is not an error. A given file
                # need not contain every requested field, as the field could
                # be located in some other metadata file.
                pass

    # Warn if we cannot find any of the requested metadata
    for field in metadata_extract_fns:
        if field in found:
            continue
        print("WARNING: Unable to find requested metadata field "
              f"`{field}`, setting it to None.")
        found[field] = None
    return found
def main(args):
    """Extract KV cache scaling factors and save them as a JSON file.

    Reads the metadata and tensor files found in args.quantized_model,
    cross-checks the tensor-parallel size against args.tp_size (if given),
    and writes the scaling factors, formatted according to QuantParamSchema,
    to args.output_name inside args.output_dir (or inside the model
    directory when args.output_dir is None).

    If no scaling factors are found, nothing is written.
    """
    metadata_extract_fns = {
        "model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"],
        "tp_size": lambda json_dict: int(json_dict["tensor_parallel"]),
        "model_dtype": lambda json_dict: json_dict["dtype"]
    }
    recovered_metadata = _metadata_extractor(args.quantized_model,
                                             metadata_extract_fns)
    if args.tp_size is not None:
        metadata_tp_size = recovered_metadata["tp_size"]
        if metadata_tp_size is not None:
            assert args.tp_size == metadata_tp_size, \
                f"User expected TP world size = {args.tp_size} " \
                f"but found TP world size = {metadata_tp_size} from metadata!"
    # A user-supplied TP size takes precedence over the one from metadata.
    expected_tp_size = args.tp_size or recovered_metadata["tp_size"]
    rank_keyword = "rank"
    hf_tensor_files, use_safetensors = _prepare_hf_weights(
        args.quantized_model, args.load_format)
    rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors,
                                           rank_keyword, expected_tp_size)
    if rank_scales_map is None:
        # No scaling factors were found, so there is nothing to save.
        # Stop here instead of failing on the missing mapping below.
        return

    # Postprocess: formatting to the current schema. Consider pulling it
    # out into a dedicated function should it ever become more complicated.
    rank_scales_map = {
        rank: {k: scale[k]
               for k in sorted(scale.keys())}
        for rank, scale in rank_scales_map.items()
    }
    # TODO: Expand this with activation and weights scaling factors when
    # they are used in the future
    schema = QuantParamSchema(
        model_type=recovered_metadata["model_type"],
        kv_cache={
            "dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else
                      recovered_metadata["model_dtype"]),
            "scaling_factor":
            rank_scales_map
        },
    )

    if args.output_dir is None:
        output_file = os.path.join(args.quantized_model, args.output_name)
    else:
        os.makedirs(args.output_dir, exist_ok=True)
        output_file = os.path.join(args.output_dir, args.output_name)
    with open(output_file, 'w') as f:
        f.write(schema.model_dump_json(indent=4))
    print(f"Completed! KV cache scaling factors saved to {output_file}")
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    cli_parser = argparse.ArgumentParser(
        description="This simple utility extracts the "
        "KV cache scaling factors from a quantized HF model "
        "and saves them to a JSON file compatible with later "
        "use by vLLM (pass this file to the appropriate "
        "runtime typically using the argument "
        "--quantization-param-path <filename>). This is only used "
        "if the KV cache dtype is FP8 and on ROCm (AMD GPU).")

    # Input model: the only mandatory option.
    cli_parser.add_argument(
        "--quantized_model",
        required=True,
        help="Specify the directory containing a single quantized HF model. "
        "It is expected that the quantization format is FP8_E4M3, for use "
        "on ROCm (AMD GPU).")
    cli_parser.add_argument(
        "--load_format",
        default="auto",
        choices=["auto", "safetensors", "npz", "pt"],
        help="Optionally specify the format of the model's tensor files "
        "containing the KV cache scaling factors.")

    # Output location.
    cli_parser.add_argument(
        "--output_dir",
        default=None,
        help="Optionally specify the output directory. By default the "
        "KV cache scaling factors will be saved in the model directory, "
        "however you can override this behavior here.")
    cli_parser.add_argument(
        "--output_name",
        # TODO: Change this once additional scaling factors are enabled
        default="kv_cache_scales.json",
        help="Optionally specify the output filename.")

    # Optional tensor-parallel size used as a consistency check.
    cli_parser.add_argument(
        "--tp_size",
        type=int,
        default=None,
        help="Optionally specify the tensor-parallel (TP) size that the "
        "quantized model should correspond to. If specified, during KV "
        "cache scaling factor extraction the observed TP size will be "
        "checked against this and an error will be raised if there is "
        "a mismatch. If not specified, the quantized model's expected "
        "TP size is instead inferred from the largest TP rank observed. "
        "The expected TP size is cross-checked against the TP ranks "
        "observed in the quantized model and an error is raised if any "
        "discrepancies are found.")

    args = cli_parser.parse_args()
    main(args)
### Quantizer Utilities
`quantize.py`: NVIDIA Quantization utilities using AMMO, ported from TensorRT-LLM:
`https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py`
### Prerequisite
#### AMMO (AlgorithMic Model Optimization) Installation: nvidia-ammo 0.7.1 or later
`pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo`
#### AMMO Download (code and docs)
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.5.0.tar.gz`
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.7.1.tar.gz`
### Usage
#### For FP8, run on an H100 system for speed; the number of GPUs needed depends on the model size
#### Example: quantize Llama2-7b model from HF to FP8 with FP8 KV Cache:
`python quantize.py --model_dir ./ll2-7b --dtype float16 --qformat fp8 --kv_cache_dtype fp8 --output_dir ./ll2_7b_fp8 --calib_size 512 --tp_size 1`
Outputs: the model structure and the quantized model with its parameters (including scaling factors) are written as JSON and Safetensors files (the npz file is generated for reference only)
```
# ll ./ll2_7b_fp8/
total 19998244
drwxr-xr-x 2 root root 4096 Feb 7 01:08 ./
drwxrwxr-x 8 1060 1061 4096 Feb 7 01:08 ../
-rw-r--r-- 1 root root 176411 Feb 7 01:08 llama_tp1.json
-rw-r--r-- 1 root root 13477087480 Feb 7 01:09 llama_tp1_rank0.npz
-rw-r--r-- 1 root root 7000893272 Feb 7 01:08 rank0.safetensors
#
```
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from examples/quantization/hf_ptq.py
"""
import argparse
import copy
import json
import random
import time
import ammo.torch.quantization as atq
import numpy as np
import torch
from ammo.torch.export import export_model_config
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
# Seed applied to the `random` and `numpy` generators in main().
RAND_SEED = 1234
# Default `model_max_length` handed to the tokenizer.
MAX_SEQ_LEN = 2048

# Quantization config in which every listed quantizer is disabled; used for
# the formats that do not go through an AMMO preset ("int8_wo", "int4_wo",
# "full_prec").
EMPTY_CFG = {
    "quant_cfg": {
        pattern: {
            "enable": False
        }
        for pattern in (
            "*weight_quantizer",
            "*input_quantizer",
            "*lm_head*",
            "*output_layer*",
            "default",
        )
    },
    "algorithm": "max",
}

# Output-quantizer settings, one entry per module-name pattern. Each pattern
# gets its own dict because main() updates the entries in place before
# merging them into the selected quantization config.
KV_CACHE_CFG = {
    f"*.{module}.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    }
    for module in (
        "query_key_value",
        "Wqkv",
        "W_pack",
        "c_attn",
        "k_proj",
        "v_proj",
    )
}

# Maps the --qformat command-line choice to its quantization config.
QUANT_CFG_CHOICES = {
    "int8_sq": atq.INT8_SMOOTHQUANT_CFG,
    "fp8": atq.FP8_DEFAULT_CFG,
    "int4_awq": atq.INT4_AWQ_CFG,
    "w4a8_awq": atq.W4A8_AWQ_BETA_CFG,
    "int8_wo": EMPTY_CFG,
    "int4_wo": EMPTY_CFG,
    "full_prec": EMPTY_CFG,
}

# (class-name fragment, model type) pairs; get_model_type() returns the type
# of the first fragment found, case-insensitively, in the model class name.
MODEL_NAME_PATTERN_MAP = dict([
    ("GPT2", "gpt2"),
    ("Xverse", "llama"),
    ("Llama", "llama"),
    ("Mistral", "llama"),
    ("GPTJ", "gptj"),
    ("FalconForCausalLM", "falcon"),
    ("RWForCausalLM", "falcon"),
    ("baichuan", "baichuan"),
    ("MPT", "mpt"),
    ("Bloom", "bloom"),
    ("ChatGLM", "chatglm"),
    ("QWen", "qwen"),
])
def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None):
    """Load the tokenizer stored at ckpt_path and make sure it can pad.

    The tokenizer is created with left-side padding and max_seq_len as its
    maximum model length. An AssertionError is raised if no pad token could
    be set.
    """
    print(f"Initializing tokenizer from {ckpt_path}")
    tokenizer = AutoTokenizer.from_pretrained(
        ckpt_path,
        model_max_length=max_seq_len,
        padding_side="left",
        trust_remote_code=True,
    )

    if model_type == "qwen":
        # qwen use token id 151643 as pad and eos tokens
        for special_token in ("pad_token", "eos_token"):
            setattr(tokenizer, special_token,
                    tokenizer.convert_ids_to_tokens(151643))

    # can't set attribute 'pad_token' for "<unk>"
    if tokenizer.pad_token != "<unk>":
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    assert tokenizer.pad_token is not None, f"Pad token for {model_type} cannot be set!"
    return tokenizer
def get_model(ckpt_path, dtype="fp16", device="cuda"):
    """Load the HuggingFace causal LM at ckpt_path in evaluation mode.

    dtype must be one of "bf16"/"bfloat16", "fp16"/"float16" or
    "fp32"/"float32"; anything else raises NotImplementedError. The model
    itself is loaded with torch_dtype="auto", so dtype is only compared
    against the loaded parameters to print a warning on mismatch.
    device is accepted but not used by this function.
    """
    print(f"Initializing model from {ckpt_path}")

    dtype_aliases = {
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp16": torch.float16,
        "float16": torch.float16,
        "fp32": torch.float32,
        "float32": torch.float32,
    }
    if dtype not in dtype_aliases:
        raise NotImplementedError(f"Unknown dtype {dtype}")
    dtype = dtype_aliases[dtype]

    model = AutoModelForCausalLM.from_pretrained(ckpt_path,
                                                 device_map="auto",
                                                 torch_dtype="auto",
                                                 trust_remote_code=True)
    model.eval()

    loaded_dtype = next(model.parameters()).dtype
    if loaded_dtype != dtype:
        print(
            f"[TensorRT-LLM][WARNING] The manually set model data type is {dtype}, "
            f"but the data type of the HuggingFace model is {loaded_dtype}.")
    return model
def get_model_type(model):
    """Return the model type matching the class name of model, or None.

    The first MODEL_NAME_PATTERN_MAP key that occurs (case-insensitively) in
    the class name determines the result.
    """
    class_name = type(model).__name__.lower()
    matches = (model_type
               for pattern, model_type in MODEL_NAME_PATTERN_MAP.items()
               if pattern.lower() in class_name)
    return next(matches, None)
def get_calib_dataloader(data="cnn_dailymail",
                         tokenizer=None,
                         batch_size=1,
                         calib_size=512,
                         block_size=512,
                         device=None):
    """Build a DataLoader over tokenized calibration samples.

    The first calib_size texts of the chosen dataset ("pileval" or
    "cnn_dailymail") are tokenized with padding and truncation to
    block_size, optionally moved to device, and their input ids are served
    in batches of batch_size without shuffling. Any other dataset name
    raises NotImplementedError.
    """
    print("Loading calibration dataset")
    if data == "pileval":
        samples = load_dataset(
            "json",
            data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
            split="train")["text"][:calib_size]
    elif data == "cnn_dailymail":
        samples = load_dataset("cnn_dailymail", name="3.0.0",
                               split="train")["article"][:calib_size]
    else:
        raise NotImplementedError

    encoded = tokenizer.batch_encode_plus(samples,
                                          return_tensors="pt",
                                          padding="max_length",
                                          truncation=True,
                                          max_length=block_size)
    if device:
        encoded = encoded.to(device)

    return DataLoader(encoded["input_ids"],
                      batch_size=batch_size,
                      shuffle=False)
def quantize_model(model, quant_cfg, calib_dataloader=None):
    """Quantize model with atq.quantize using quant_cfg and return it.

    calib_dataloader, when given, supplies the batches that are fed through
    the model by the forward loop passed to atq.quantize. The elapsed time
    is printed.
    """

    def calibrate_loop():
        """Run every calibration batch through the model, if there are any."""
        if calib_dataloader is None:
            return
        for batch_idx, batch in enumerate(calib_dataloader):
            print(f"Calibrating batch {batch_idx}")
            model(batch)

    print("Starting quantization...")
    started_at = time.time()
    atq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    finished_at = time.time()
    elapsed = finished_at - started_at
    print("Quantization done. Total time used: {:.2f} s.".format(elapsed))
    return model
def main(args):
    """Quantize a HuggingFace model as requested by args and export it.

    Loads the model and tokenizer from args.model_dir, calibrates and
    quantizes the model unless a non-quantizing format was requested
    without a KV cache dtype, and exports the result to args.output_dir.

    Raises:
        EnvironmentError: no CUDA device is available.
        ValueError: args.qformat is not a key of QUANT_CFG_CHOICES.
    """
    if not torch.cuda.is_available():
        raise EnvironmentError("GPU is required for inference.")
    # Seed the Python and NumPy random number generators.
    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    model = get_model(args.model_dir, args.dtype, args.device)
    model_type = get_model_type(model)
    tokenizer = get_tokenizer(args.model_dir, model_type=model_type)
    # The weight-only / full-precision formats need calibration only when
    # the KV cache is to be quantized.
    if args.qformat in ["full_prec", "int8_wo", "int4_wo"
                        ] and args.kv_cache_dtype is None:
        print(f"No quantization applied, export {args.dtype} model")
    else:
        if "awq" in args.qformat:
            # The calibration size is capped at 32 samples for AWQ.
            if args.calib_size > 32:
                print(
                    f"AWQ calibration could take longer with calib_size = {args.calib_size}, Using"
                    " calib_size=32 instead")
                args.calib_size = 32
            print(
                "\nAWQ calibration could take longer than other calibration methods. Please"
                " increase the batch size to speed up the calibration process. Batch size can be"
                " set by adding the argument --batch_size <batch_size> to the command line.\n"
            )
        calib_dataloader = get_calib_dataloader(
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            calib_size=args.calib_size,
            device=args.device,
        )
        if args.qformat in QUANT_CFG_CHOICES:
            quant_cfg = QUANT_CFG_CHOICES[args.qformat]
        else:
            raise ValueError(
                f"Unsupported quantization format: {args.qformat}")
        if "awq" in args.qformat:
            # Work on a private copy so that the block size override does
            # not modify the shared preset.
            quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat])
            weight_quantizer = quant_cfg["quant_cfg"][
                "*weight_quantizer"]  # type: ignore
            if isinstance(weight_quantizer, list):
                weight_quantizer = weight_quantizer[0]
            weight_quantizer["block_sizes"][-1] = args.awq_block_size
        if args.kv_cache_dtype is not None:
            if args.kv_cache_dtype == "fp8":
                # NOTE(review): (4, 3) presumably selects the FP8 E4M3
                # format; confirm against the AMMO documentation.
                for value in KV_CACHE_CFG.values():
                    value.update({"num_bits": (4, 3)})  # type: ignore
            # NOTE(review): for non-AWQ formats quant_cfg is the shared
            # module-level config object, so this update modifies it in
            # place; confirm that is acceptable for a one-shot script.
            quant_cfg["quant_cfg"].update(KV_CACHE_CFG)  # type: ignore
        print(quant_cfg)
        model = quantize_model(model, quant_cfg, calib_dataloader)
    with torch.inference_mode():
        if model_type is None:
            print(
                f"Unknown model type {type(model).__name__}. Continue exporting..."
            )
            model_type = f"unknown:{type(model).__name__}"
        export_path = args.output_dir
        start_time = time.time()
        if args.qformat == "int4_awq" and model_type == "qwen":
            torch.save(model.state_dict(), export_path)
        else:
            # npz export is requested only for model types outside this list.
            export_npz = (model_type not in [
                'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan'
            ])
            # export safetensors
            # NOTE(review): args.dtype is resolved with getattr(torch, ...),
            # while get_model() also accepts short aliases such as "fp16";
            # confirm those aliases are valid torch attribute names.
            export_model_config(
                model,
                model_type,
                getattr(torch, args.dtype),
                export_dir=export_path,
                inference_tensor_parallel=args.tp_size,
                inference_pipeline_parallel=args.pp_size,
                # export_tensorrt_llm_config=(not export_npz),
                export_tensorrt_llm_config=False,
                export_npz=export_npz)
        # Workaround for wo quantization
        # NOTE(review): this rewrites {export_path}/config.json; confirm the
        # file is produced when export_tensorrt_llm_config is False.
        if args.qformat in ["int8_wo", "int4_wo", "full_prec"]:
            with open(f"{export_path}/config.json", 'r') as f:
                tensorrt_llm_config = json.load(f)
            if args.qformat == "int8_wo":
                tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16'
            elif args.qformat == "int4_wo":
                tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16'
            else:
                tensorrt_llm_config["quantization"]["quant_algo"] = None
            with open(f"{export_path}/config.json", "w") as f:
                json.dump(tensorrt_llm_config, f, indent=4)
        end_time = time.time()
        print("Quantized model exported to {} \nTotal time used {:.2f} s.".
              format(export_path, end_time - start_time))
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    cli_parser = argparse.ArgumentParser(description=__doc__)

    # Model location and runtime settings.
    cli_parser.add_argument("--model_dir",
                            required=True,
                            help="Specify where the HuggingFace model is")
    cli_parser.add_argument("--device", default="cuda")
    cli_parser.add_argument("--dtype",
                            default="float16",
                            help="Model data type.")

    # Quantization and calibration settings.
    cli_parser.add_argument(
        "--qformat",
        default="full_prec",
        choices=[
            "fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo",
            "full_prec"
        ],
        help="Quantization format.",
    )
    cli_parser.add_argument("--batch_size",
                            type=int,
                            default=1,
                            help="Batch size for calibration.")
    cli_parser.add_argument("--calib_size",
                            type=int,
                            default=512,
                            help="Number of samples for calibration.")

    # Export settings.
    cli_parser.add_argument("--output_dir", default="exported_model")
    cli_parser.add_argument("--tp_size", type=int, default=1)
    cli_parser.add_argument("--pp_size", type=int, default=1)
    cli_parser.add_argument("--awq_block_size", type=int, default=128)
    cli_parser.add_argument("--kv_cache_dtype",
                            default=None,
                            choices=["int8", "fp8", None],
                            help="KV Cache dtype.")

    args = cli_parser.parse_args()
    main(args)
...@@ -13,6 +13,10 @@ build-backend = "setuptools.build_meta" ...@@ -13,6 +13,10 @@ build-backend = "setuptools.build_meta"
[tool.ruff] [tool.ruff]
# Allow lines to be as long as 80. # Allow lines to be as long as 80.
line-length = 80 line-length = 80
exclude = [
# External file, leaving license intact
"examples/fp8/quantizer/quantize.py"
]
[tool.ruff.lint] [tool.ruff.lint]
select = [ select = [
......
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.0230364128947258,
"1": 0.01979283057153225,
"2": 0.0241350457072258,
"3": 0.0308314748108387,
"4": 0.0430733822286129,
"5": 0.0370396226644516,
"6": 0.0306222103536129,
"7": 0.0357491634786129,
"8": 0.0358189195394516,
"9": 0.0443289652466774,
"10": 0.0433175228536129,
"11": 0.0416782945394516,
"12": 0.0366908498108387,
"13": 0.0432477705180645,
"14": 0.0410505048930645,
"15": 0.0457589291036129,
"16": 0.0418526791036129,
"17": 0.0432477705180645,
"18": 0.0469447560608387,
"19": 0.0514787957072258,
"20": 0.0541294664144516,
"21": 0.0587681382894516,
"22": 0.0625,
"23": 0.0585588738322258,
"24": 0.0600237175822258,
"25": 0.0588030144572258,
"26": 0.0531180277466774,
"27": 0.06396484375,
"28": 0.0603027381002903,
"29": 0.0582101047039032,
"30": 0.0625348836183548,
"31": 0.0585588738322258,
"32": 0.0582798570394516,
"33": 0.0575125589966774,
"34": 0.0590820349752903,
"35": 0.0614188089966774,
"36": 0.0631975457072258,
"37": 0.0615931935608387,
"38": 0.0601283498108387,
"39": 0.0571986623108387,
"40": 0.0670340433716774,
"41": 0.0523507259786129,
"42": 0.0547223798930645,
"43": 0.0631975457072258,
"44": 0.0663713738322258,
"45": 0.0603376142680645,
"46": 0.0652204304933548,
"47": 0.0734514519572258,
"48": 0.0693708211183548,
"49": 0.0725446492433548,
"50": 0.0627790242433548,
"51": 0.0691266804933548,
"52": 0.0688825398683548,
"53": 0.068429134786129,
"54": 0.0605119988322258,
"55": 0.0799386203289032,
"56": 0.0853097140789032,
"57": 0.0661969929933548,
"58": 0.0689871683716774,
"59": 0.0724051371216774,
"60": 0.0541643425822258,
"61": 0.0626743882894516,
"62": 0.0628487765789032,
"63": 0.0607212632894516,
"64": 0.0589076466858387,
"65": 0.0451660193502903,
"66": 0.0453055277466774,
"67": 0.0414341539144516,
"68": 0.0385044664144516,
"69": 0.0414341539144516,
"70": 0.0466308631002903,
"71": 0.0399693101644516,
"72": 0.0437011756002903,
"73": 0.0434221550822258,
"74": 0.0428989976644516,
"75": 0.0401785746216774,
"76": 0.0431082621216774,
"77": 0.0484444759786129,
"78": 0.0417829267680645,
"79": 0.0418178029358387
}
}
}
}
\ No newline at end of file
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.0152239128947258,
"1": 0.0188860222697258,
"2": 0.0354178324341774,
"3": 0.0376674123108387,
"4": 0.0418526791036129,
"5": 0.0433175228536129,
"6": 0.0397600457072258,
"7": 0.0424455925822258,
"8": 0.0415387861430645,
"9": 0.0408412404358387,
"10": 0.0395856611430645,
"11": 0.0377371683716774,
"12": 0.0400739423930645,
"13": 0.040771484375,
"14": 0.0393415205180645,
"15": 0.0369001142680645,
"16": 0.03857421875,
"17": 0.0387486070394516,
"18": 0.0403180830180645,
"19": 0.0396205373108387,
"20": 0.0375627800822258,
"21": 0.0407366082072258,
"22": 0.0432477705180645,
"23": 0.0377022884786129,
"24": 0.0399693101644516,
"25": 0.0374581478536129,
"26": 0.0413295216858387,
"27": 0.0442243330180645,
"28": 0.0424804724752903,
"29": 0.0456891767680645,
"30": 0.0409109964966774,
"31": 0.0482352152466774
}
}
}
}
...@@ -32,7 +32,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256 ...@@ -32,7 +32,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256
BLOCK_SIZES = [16, 32] BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True] USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0] SEEDS = [0]
CUDA_DEVICES = [ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
...@@ -172,6 +172,9 @@ def test_paged_attention( ...@@ -172,6 +172,9 @@ def test_paged_attention(
device) device)
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
kv_scale = 1.0
# Call the paged attention kernel. # Call the paged attention kernel.
output = torch.empty_like(query) output = torch.empty_like(query)
if version == "v1": if version == "v1":
...@@ -188,6 +191,7 @@ def test_paged_attention( ...@@ -188,6 +191,7 @@ def test_paged_attention(
max_context_len, max_context_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale,
) )
elif version == "v2": elif version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) // num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
...@@ -219,12 +223,13 @@ def test_paged_attention( ...@@ -219,12 +223,13 @@ def test_paged_attention(
max_context_len, max_context_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale,
) )
else: else:
raise AssertionError(f"Unknown version: {version}") raise AssertionError(f"Unknown version: {version}")
# Run the reference implementation. # Run the reference implementation.
if kv_cache_dtype == "fp8_e5m2": if kv_cache_dtype == "fp8":
# Convert cache data back to dtype. # Convert cache data back to dtype.
x = 16 // torch.tensor([], dtype=dtype).element_size() x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
...@@ -232,14 +237,14 @@ def test_paged_attention( ...@@ -232,14 +237,14 @@ def test_paged_attention(
dequantized_key_cache = torch.empty(size=key_cache_shape, dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype, dtype=dtype,
device=device) device=device)
cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache) cache_ops.convert_fp8(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape, dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype, dtype=dtype,
device=device) device=device)
cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache) cache_ops.convert_fp8(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache value_cache = dequantized_value_cache
ref_output = torch.empty_like(query) ref_output = torch.empty_like(query)
...@@ -263,7 +268,8 @@ def test_paged_attention( ...@@ -263,7 +268,8 @@ def test_paged_attention(
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test. # so we use a relaxed tolerance for the test.
if kv_cache_dtype == "fp8_e5m2": atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8":
atol, rtol = 1e-2, 1e-5 atol, rtol = 1e-2, 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
......
...@@ -5,6 +5,7 @@ import pytest ...@@ -5,6 +5,7 @@ import pytest
import torch import torch
from vllm._C import cache_ops from vllm._C import cache_ops
from vllm.utils import is_hip
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
...@@ -23,7 +24,7 @@ SEEDS = [0] ...@@ -23,7 +24,7 @@ SEEDS = [0]
CUDA_DEVICES = [ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] ]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] KV_CACHE_DTYPE = ["auto", "fp8"]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
...@@ -105,6 +106,7 @@ def test_copy_blocks( ...@@ -105,6 +106,7 @@ def test_copy_blocks(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode() @torch.inference_mode()
def test_reshape_and_cache( def test_reshape_and_cache(
kv_cache_factory, kv_cache_factory,
...@@ -116,7 +118,10 @@ def test_reshape_and_cache( ...@@ -116,7 +118,10 @@ def test_reshape_and_cache(
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
device: str, device: str,
kv_cache_dtype: str,
) -> None: ) -> None:
if not is_hip() and kv_cache_dtype == "fp8":
pytest.skip() # This test is not tuned for e5m2 cuda precision
random.seed(seed) random.seed(seed)
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -132,17 +137,33 @@ def test_reshape_and_cache( ...@@ -132,17 +137,33 @@ def test_reshape_and_cache(
# Create the KV caches. # Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
num_heads, head_size, dtype, num_heads, head_size,
None, seed, device) kv_cache_dtype, dtype, seed,
device)
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0], value_caches[0]
# Clone the KV caches. # Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
cache_ops.convert_fp8(key_cache, cloned_key_cache)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
cache_ops.convert_fp8(value_cache, cloned_value_cache)
else:
cloned_key_cache = key_cache.clone() cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone() cloned_value_cache = value_cache.clone()
# Using default kv_scale
kv_scale = 1.0
# Call the reshape_and_cache kernel. # Call the reshape_and_cache kernel.
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping, "auto") slot_mapping, kv_cache_dtype, kv_scale)
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
cache_ops.convert_fp8(key_cache, result_key_cache)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
cache_ops.convert_fp8(value_cache, result_value_cache)
# Run the reference implementation. # Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
...@@ -156,6 +177,16 @@ def test_reshape_and_cache( ...@@ -156,6 +177,16 @@ def test_reshape_and_cache(
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i] cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
cloned_value_cache[block_idx, :, :, block_offset] = value[i] cloned_value_cache[block_idx, :, :, block_offset] = value[i]
if kv_cache_dtype == "fp8":
assert torch.allclose(result_key_cache,
cloned_key_cache,
atol=0.001,
rtol=0.1)
assert torch.allclose(result_value_cache,
cloned_value_cache,
atol=0.001,
rtol=0.1)
else:
assert torch.allclose(key_cache, cloned_key_cache) assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache) assert torch.allclose(value_cache, cloned_value_cache)
...@@ -169,6 +200,7 @@ def test_reshape_and_cache( ...@@ -169,6 +200,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode() @torch.inference_mode()
def test_swap_blocks( def test_swap_blocks(
kv_cache_factory, kv_cache_factory,
...@@ -181,7 +213,12 @@ def test_swap_blocks( ...@@ -181,7 +213,12 @@ def test_swap_blocks(
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
device: str, device: str,
kv_cache_dtype: str,
) -> None: ) -> None:
if kv_cache_dtype == "fp8" and "cpu" in direction:
pytest.skip()
if not is_hip() and kv_cache_dtype == "fp8":
pytest.skip() # This test is not tuned for e5m2 cuda precision
random.seed(seed) random.seed(seed)
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -202,13 +239,13 @@ def test_swap_blocks( ...@@ -202,13 +239,13 @@ def test_swap_blocks(
# Create the KV caches on the first device. # Create the KV caches on the first device.
src_key_caches, src_value_caches = kv_cache_factory( src_key_caches, src_value_caches = kv_cache_factory(
num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed, num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
src_device) seed, src_device)
# Create the KV caches on the second device. # Create the KV caches on the second device.
dist_key_caches, dist_value_caches = kv_cache_factory( dist_key_caches, dist_value_caches = kv_cache_factory(
num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed, num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
dst_device) seed, dst_device)
src_key_caches_clone = src_key_caches[0].clone() src_key_caches_clone = src_key_caches[0].clone()
src_value_caches_clone = src_value_caches[0].clone() src_value_caches_clone = src_value_caches[0].clone()
...@@ -223,3 +260,40 @@ def test_swap_blocks( ...@@ -223,3 +260,40 @@ def test_swap_blocks(
dist_key_caches[0][dst].cpu()) dist_key_caches[0][dst].cpu())
assert torch.allclose(src_value_caches_clone[src].cpu(), assert torch.allclose(src_value_caches_clone[src].cpu(),
dist_value_caches[0][dst].cpu()) dist_value_caches[0][dst].cpu())
@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random tensor through ``cache_ops.convert_fp8``.

    A tensor of ``dtype`` is filled with uniform random values, converted
    into a ``uint8`` buffer, converted back into a tensor of the original
    dtype, and compared against the input with a relaxed tolerance.
    The test is skipped entirely unless ``is_hip()`` is true.
    """
    # Seed every RNG the test touches so failures are reproducible.
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Symmetric sampling range for the input values.
    value_bound = 224.0
    original = torch.empty((num_blocks, num_heads, head_size, block_size),
                           dtype=dtype,
                           device=device)
    original.uniform_(-value_bound, value_bound)

    # dtype -> FP8 (stored as raw bytes).
    quantized = torch.empty_like(original, dtype=torch.uint8)
    cache_ops.convert_fp8(original, quantized)

    # FP8 -> dtype.
    restored = torch.empty_like(original)
    cache_ops.convert_fp8(quantized, restored)

    # FP8 is lossy, so only approximate equality is expected.
    assert torch.allclose(original, restored, atol=0.001, rtol=0.1)
...@@ -81,5 +81,6 @@ class AttentionImpl(ABC): ...@@ -81,5 +81,6 @@ class AttentionImpl(ABC):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
kv_scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -156,6 +156,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -156,6 +156,7 @@ class FlashAttentionImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
kv_scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
...@@ -184,7 +185,8 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -184,7 +185,8 @@ class FlashAttentionImpl(AttentionImpl):
PagedAttention.write_to_paged_cache(key, value, key_cache, PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache, value_cache,
attn_metadata.slot_mapping, attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype) attn_metadata.kv_cache_dtype,
kv_scale)
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
# Prompt run. # Prompt run.
...@@ -207,6 +209,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -207,6 +209,9 @@ class FlashAttentionImpl(AttentionImpl):
) )
else: else:
# prefix-enabled attention # prefix-enabled attention
# TODO(Hai): this Triton kernel is currently broken (a regression) when
# the key/value dtype differs from the FP8 KV cache dtype; to be
# addressed separately.
output = PagedAttention.forward_prefix( output = PagedAttention.forward_prefix(
query, query,
key, key,
...@@ -233,6 +238,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -233,6 +238,7 @@ class FlashAttentionImpl(AttentionImpl):
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
kv_scale,
) )
# Reshape the output tensor. # Reshape the output tensor.
......
...@@ -178,6 +178,7 @@ class XFormersImpl(AttentionImpl): ...@@ -178,6 +178,7 @@ class XFormersImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: XFormersMetadata, attn_metadata: XFormersMetadata,
kv_scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
...@@ -205,7 +206,8 @@ class XFormersImpl(AttentionImpl): ...@@ -205,7 +206,8 @@ class XFormersImpl(AttentionImpl):
PagedAttention.write_to_paged_cache(key, value, key_cache, PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache, value_cache,
attn_metadata.slot_mapping, attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype) attn_metadata.kv_cache_dtype,
kv_scale)
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
# Prompt run. # Prompt run.
...@@ -259,6 +261,9 @@ class XFormersImpl(AttentionImpl): ...@@ -259,6 +261,9 @@ class XFormersImpl(AttentionImpl):
query, key, value, attn_metadata) query, key, value, attn_metadata)
else: else:
# prefix-enabled attention # prefix-enabled attention
# TODO(Hai): this Triton kernel is currently broken (a regression) when
# the key/value dtype differs from the FP8 KV cache dtype; to be
# addressed separately.
output = PagedAttention.forward_prefix( output = PagedAttention.forward_prefix(
query, query,
key, key,
...@@ -285,6 +290,7 @@ class XFormersImpl(AttentionImpl): ...@@ -285,6 +290,7 @@ class XFormersImpl(AttentionImpl):
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
kv_scale,
) )
# Reshape the output tensor. # Reshape the output tensor.
......
...@@ -42,5 +42,7 @@ class Attention(nn.Module): ...@@ -42,5 +42,7 @@ class Attention(nn.Module):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata) return self.impl.forward(query, key, value, kv_cache, attn_metadata,
kv_scale)
...@@ -73,6 +73,7 @@ class PagedAttention: ...@@ -73,6 +73,7 @@ class PagedAttention:
value_cache: torch.Tensor, value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float,
) -> None: ) -> None:
cache_ops.reshape_and_cache( cache_ops.reshape_and_cache(
key, key,
...@@ -81,6 +82,7 @@ class PagedAttention: ...@@ -81,6 +82,7 @@ class PagedAttention:
value_cache, value_cache,
slot_mapping.flatten(), slot_mapping.flatten(),
kv_cache_dtype, kv_cache_dtype,
kv_scale,
) )
@staticmethod @staticmethod
...@@ -95,6 +97,7 @@ class PagedAttention: ...@@ -95,6 +97,7 @@ class PagedAttention:
num_kv_heads: int, num_kv_heads: int,
scale: float, scale: float,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_scale,
) -> torch.Tensor: ) -> torch.Tensor:
output = torch.empty_like(query) output = torch.empty_like(query)
...@@ -126,6 +129,7 @@ class PagedAttention: ...@@ -126,6 +129,7 @@ class PagedAttention:
max_context_len, max_context_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale,
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
...@@ -157,6 +161,7 @@ class PagedAttention: ...@@ -157,6 +161,7 @@ class PagedAttention:
max_context_len, max_context_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale,
) )
return output return output
......
...@@ -60,6 +60,11 @@ class ModelConfig: ...@@ -60,6 +60,11 @@ class ModelConfig:
output). If None, will be derived from the model. output). If None, will be derived from the model.
quantization: Quantization method that was used to quantize the model quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized. weights. If None, we assume the model weights are not quantized.
quantization_param_path: Path to JSON file containing scaling factors.
Used to load KV cache scaling factors into the model when KV cache
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
be used to load activation and weight scaling factors when the
model dtype is FP8_E4M3 on ROCm.
enforce_eager: Whether to enforce eager execution. If True, we will enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode. disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid. If False, we will use CUDA graph and eager execution in hybrid.
...@@ -83,6 +88,7 @@ class ModelConfig: ...@@ -83,6 +88,7 @@ class ModelConfig:
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None, max_model_len: Optional[int] = None,
quantization: Optional[str] = None, quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: bool = False, enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None, max_context_len_to_capture: Optional[int] = None,
max_logprobs: int = 5, max_logprobs: int = 5,
...@@ -98,6 +104,7 @@ class ModelConfig: ...@@ -98,6 +104,7 @@ class ModelConfig:
self.code_revision = code_revision self.code_revision = code_revision
self.tokenizer_revision = tokenizer_revision self.tokenizer_revision = tokenizer_revision
self.quantization = quantization self.quantization = quantization
self.quantization_param_path = quantization_param_path
self.enforce_eager = enforce_eager self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture self.max_context_len_to_capture = max_context_len_to_capture
self.max_logprobs = max_logprobs self.max_logprobs = max_logprobs
...@@ -369,21 +376,20 @@ class CacheConfig: ...@@ -369,21 +376,20 @@ class CacheConfig:
def _verify_cache_dtype(self) -> None:
    """Validate ``self.cache_dtype`` and log a notice when FP8 is selected.

    ``"auto"`` is always accepted. ``"fp8"`` is accepted on ROCm
    unconditionally; elsewhere it is rejected when the detected nvcc CUDA
    version is older than 11.8.

    Raises:
        ValueError: if the dtype is unknown, or if ``"fp8"`` is requested
            with a CUDA toolkit older than 11.8.
    """
    if self.cache_dtype == "auto":
        return
    if self.cache_dtype != "fp8":
        raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    if not is_hip():
        nvcc_cuda_version = get_nvcc_cuda_version()
        # get_nvcc_cuda_version() is annotated Optional[Version]; comparing
        # None against a Version would raise TypeError, so the version
        # check is skipped when no nvcc version could be determined.
        if (nvcc_cuda_version is not None
                and nvcc_cuda_version < Version("11.8")):
            raise ValueError(
                "FP8 is not supported when cuda version is "
                "lower than 11.8.")
    logger.info(
        "Using fp8 data type to store kv cache. It reduces the GPU "
        "memory footprint and boosts the performance. "
        "But it may cause slight accuracy drop without scaling "
        "factors. FP8_E5M2 (without scaling) is only supported on "
        "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
        "is instead supported for common inference criteria.")
......
...@@ -21,6 +21,7 @@ class EngineArgs: ...@@ -21,6 +21,7 @@ class EngineArgs:
load_format: str = 'auto' load_format: str = 'auto'
dtype: str = 'auto' dtype: str = 'auto'
kv_cache_dtype: str = 'auto' kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
seed: int = 0 seed: int = 0
max_model_len: Optional[int] = None max_model_len: Optional[int] = None
worker_use_ray: bool = False worker_use_ray: bool = False
...@@ -159,11 +160,23 @@ class EngineArgs: ...@@ -159,11 +160,23 @@ class EngineArgs:
parser.add_argument( parser.add_argument(
'--kv-cache-dtype', '--kv-cache-dtype',
type=str, type=str,
choices=['auto', 'fp8_e5m2'], choices=['auto', 'fp8'],
default=EngineArgs.kv_cache_dtype, default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model ' help='Data type for kv cache storage. If "auto", will use model '
'data type. Note FP8 is not supported when cuda version is ' 'data type. FP8_E5M2 (without scaling) is only supported on cuda '
'lower than 11.8.') 'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria. ')
# Optional JSON file holding KV cache scaling factors for FP8 KV caches.
# Adjacent string literals are concatenated verbatim, so every fragment
# that continues a sentence must end with a space.
parser.add_argument(
    '--quantization-param-path',
    type=str,
    default=None,
    help='Path to the JSON file containing the KV cache '
    'scaling factors. This should generally be supplied, when '
    'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
    'default to 1.0, which may cause accuracy issues. '
    'FP8_E5M2 (without scaling) is only supported on cuda version '
    'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
    'supported for common inference criteria. ')
parser.add_argument('--max-model-len', parser.add_argument('--max-model-len',
type=int, type=int,
default=EngineArgs.max_model_len, default=EngineArgs.max_model_len,
...@@ -408,8 +421,8 @@ class EngineArgs: ...@@ -408,8 +421,8 @@ class EngineArgs:
self.trust_remote_code, self.download_dir, self.load_format, self.trust_remote_code, self.download_dir, self.load_format,
self.dtype, self.seed, self.revision, self.code_revision, self.dtype, self.seed, self.revision, self.code_revision,
self.tokenizer_revision, self.max_model_len, self.quantization, self.tokenizer_revision, self.max_model_len, self.quantization,
self.enforce_eager, self.max_context_len_to_capture, self.quantization_param_path, self.enforce_eager,
self.max_logprobs) self.max_context_len_to_capture, self.max_logprobs)
cache_config = CacheConfig(self.block_size, cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization, self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype, self.swap_space, self.kv_cache_dtype,
......
...@@ -97,6 +97,7 @@ class LLMEngine: ...@@ -97,6 +97,7 @@ class LLMEngine:
f"quantization={model_config.quantization}, " f"quantization={model_config.quantization}, "
f"enforce_eager={model_config.enforce_eager}, " f"enforce_eager={model_config.enforce_eager}, "
f"kv_cache_dtype={cache_config.cache_dtype}, " f"kv_cache_dtype={cache_config.cache_dtype}, "
f"quantization_param_path={model_config.quantization_param_path}, "
f"device_config={device_config.device}, " f"device_config={device_config.device}, "
f"seed={model_config.seed})") f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
......
"""
This file contains the Pydantic schemas for various quantization-related
parameters. When a relevant quantization technique is specified, these
parameters are loaded in the form of a JSON alongside the model weights
and augment the model with additional information needed for use of that
technique. The format of this JSON should be specified by one or more
schemas contained here.
For example, when the KV cache is quantized to FP8-E4M3 (currently only
possible on ROCm), the model can be optionally augmented with KV cache
scaling factors.
"""
from typing import Dict, Optional
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
class KVCacheQuantSchema(BaseModel):
    """Schema for the KV cache section of a quantization-parameter JSON.

    Validation failures are raised as ``ValueError`` rather than with
    ``assert`` so the checks still run under ``python -O``. The checks
    that compare against the running engine only execute when a
    validation context is supplied; they read ``tp_size``, ``tp_rank``
    and ``num_hidden_layers`` from it.
    """

    dtype: str
    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
    # layer indices to their per-tensor KV cache scaling factor.
    # TODO: Consider pulling this and its validation methods out into its
    # own schema class (tricky as its members are variable)
    scaling_factor: Dict[int, Dict[int, float]]

    @model_validator(mode="after")
    def check_is_fp8(self) -> "KVCacheQuantSchema":
        """Reject scaling factors produced for any dtype but float8_e4m3fn."""
        if self.dtype != "float8_e4m3fn":
            raise ValueError(
                "Loaded scaling factors intended for KV cache dtype = "
                f"{self.dtype} rather than float8_e4m3fn!")
        return self

    @model_validator(mode="after")
    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
        """Check that the map covers every TP rank with a full layer map."""
        context = info.context
        if context:
            tp_size = context["tp_size"]
            num_hidden_layers = context["num_hidden_layers"]
            if len(self.scaling_factor) != tp_size:
                raise ValueError(
                    f"Loaded dictionary has TP size {len(self.scaling_factor)} "
                    f"but LLM engine is currently running with TP size {tp_size}.")
            for tp_rank, layer_maps in self.scaling_factor.items():
                if len(layer_maps) != num_hidden_layers:
                    raise ValueError(
                        f"KV cache scales map for TP rank {tp_rank} is malformed. "
                        f"Expected {num_hidden_layers} layers, got "
                        f"{len(layer_maps)}.")
            # The count matched above; now make sure the keys are exactly
            # the ranks 0..tp_size-1.
            for i in range(tp_size):
                if i not in self.scaling_factor:
                    raise ValueError(
                        f"KV cache scales map for TP rank {i} not found.")
        return self

    @model_validator(mode="after")
    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
        """Check that the current TP rank has a scale for every layer index."""
        context = info.context
        if context:
            tp_rank = context["tp_rank"]
            num_hidden_layers = context["num_hidden_layers"]
            layer_scales_map = self.scaling_factor[tp_rank]
            for i in range(num_hidden_layers):
                if i not in layer_scales_map:
                    raise ValueError(
                        f"Could not find KV cache scales for layer {i} in "
                        f"TP rank {tp_rank}.")
        return self
class QuantParamSchema(BaseModel):
    """Top-level schema of the quantization-parameter JSON file.

    Holds the (optional-valued) model type the parameters were produced
    for and the KV cache scaling factors.
    """

    # TODO: Generalize and extend with more fields
    # (e.g. weights/activations params) once functionality is enabled

    # NOTE(review): presumably clears pydantic's protected "model_" prefix
    # so that the `model_type` field below is accepted without a warning.
    model_config = ConfigDict(protected_namespaces=())
    model_type: Optional[str]
    kv_cache: KVCacheQuantSchema

    @model_validator(mode="after")
    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
        """Ensure the file's model type matches the one given in the context.

        The check is skipped when no context is supplied or when the
        context carries no ``model_type``. A mismatch raises ``ValueError``
        (not ``assert``) so the check survives ``python -O``.
        """
        context = info.context
        if context:
            model_type = context.get("model_type", None)
            if model_type is not None and model_type != self.model_type:
                raise ValueError(
                    f"Model type is {model_type} but loaded "
                    f"scaling factors belonging to different "
                    f"model type {self.model_type}!")
        return self
...@@ -41,11 +41,13 @@ from vllm.model_executor.layers.sampler import Sampler ...@@ -41,11 +41,13 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader, from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator,
kv_cache_scales_loader)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.utils import is_hip
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
...@@ -115,6 +117,15 @@ class LlamaAttention(nn.Module): ...@@ -115,6 +117,15 @@ class LlamaAttention(nn.Module):
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
# This will be overwritten by model initialization if we are using it.
# N.B. currently we only support per tensor scalar scaling factors
# & only applicable to ROCm (AMD GPU).
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
self.kv_scale = 1.0
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
self.head_dim, self.head_dim,
...@@ -153,7 +164,8 @@ class LlamaAttention(nn.Module): ...@@ -153,7 +164,8 @@ class LlamaAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
self.kv_scale)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -402,3 +414,27 @@ class LlamaForCausalLM(nn.Module): ...@@ -402,3 +414,27 @@ class LlamaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
    """Populate each attention layer's ``kv_scale`` from a JSON file.

    If this function is called, it should always initialize the KV cache
    scale factors (or else raise an exception). Handled exceptions must
    therefore leave the scale factors in a known good (dummy) state.

    Args:
        quantization_param_path: path handed to ``kv_cache_scales_loader``.

    Raises:
        RuntimeError: if a layer's ``self_attn`` has no ``kv_scale``
            attribute to write to.
    """
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()
    per_layer_scales = kv_cache_scales_loader(
        quantization_param_path,
        tp_rank,
        tp_size,
        self.config.num_hidden_layers,
        self.config.__class__.model_type,
    )
    for layer_idx, scaling_factor in per_layer_scales:
        attn_module = self.model.layers[layer_idx].self_attn
        if is_hip():
            # The scaling factor convention we are assuming is
            # quantized_value * scaling_factor ~= true_value
            # which is consistent with the practice of setting
            # scaling_factor = tensor_amax / FPtype_max
            # NOTE(review): the factor is doubled on ROCm; the reason is
            # not visible here (presumably a difference in the FP8 format's
            # representable range) -- confirm.
            scaling_factor *= 2
        if not hasattr(attn_module, "kv_scale"):
            raise RuntimeError("Self attention has no KV cache scaling "
                               "factor attribute!")
        attn_module.kv_scale = scaling_factor
...@@ -5,7 +5,7 @@ import hashlib ...@@ -5,7 +5,7 @@ import hashlib
import json import json
import os import os
from collections import defaultdict from collections import defaultdict
from typing import Any, Iterator, List, Optional, Tuple from typing import Any, Iterable, Iterator, List, Optional, Tuple
import filelock import filelock
import numpy as np import numpy as np
...@@ -18,6 +18,7 @@ from vllm.config import ModelConfig ...@@ -18,6 +18,7 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig, from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config) get_quantization_config)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -275,6 +276,46 @@ def hf_model_weights_iterator( ...@@ -275,6 +276,46 @@ def hf_model_weights_iterator(
torch.cuda.empty_cache() torch.cuda.empty_cache()
def kv_cache_scales_loader(
        filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
        model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
    """
    A simple utility to read in KV cache scaling factors that have been
    previously serialized to disk. Used by the model to populate the appropriate
    KV cache scaling factors. The serialization should represent a dictionary
    whose keys are the TP ranks and values are another dictionary mapping layers
    to their KV cache scaling factors.
    Keep this function in sync with the output of examples/fp8/extract_scales.py

    Args:
        filename: path of the JSON file to read.
        tp_rank: tensor-parallel rank whose scales are returned.
        tp_size: tensor-parallel world size, passed to schema validation.
        num_hidden_layers: expected number of layers, passed to validation.
        model_type: expected model type, passed to validation (may be None).

    Returns:
        The ``(layer_index, scaling_factor)`` items for ``tp_rank``, or an
        empty list if the file could not be read or validated. This is a
        deliberate best-effort: every failure is logged, not raised.
    """
    try:
        # Decode explicitly as UTF-8 instead of the locale's default
        # encoding, and release the file handle before validating.
        with open(filename, encoding="utf-8") as f:
            schema_dct = json.load(f)
        context = {
            "model_type": model_type,
            "num_hidden_layers": num_hidden_layers,
            "tp_rank": tp_rank,
            "tp_size": tp_size,
        }
        schema = QuantParamSchema.model_validate(schema_dct,
                                                 context=context)
        layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
        return layer_scales_map.items()

    except FileNotFoundError:
        logger.error(f"File or directory '{filename}' not found.")
    except json.JSONDecodeError:
        logger.error(f"Error decoding JSON in file '{filename}'.")
    except Exception as e:
        logger.error(f"An error occurred while reading '{filename}': {e}")
    # This section is reached if and only if any of the excepts are hit
    # Return an empty iterable (list) => no KV cache scales are loaded
    # which ultimately defaults to 1.0 scales
    logger.warning("Defaulting to KV cache scaling factors = 1.0 "
                   f"for all layers in TP rank {tp_rank} "
                   "as an error occurred during loading.")
    return []
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
"""convert PySafeSlice object from safetensors to torch.Tensor """convert PySafeSlice object from safetensors to torch.Tensor
......
...@@ -25,7 +25,7 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -25,7 +25,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half, "half": torch.half,
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
"float": torch.float, "float": torch.float,
"fp8_e5m2": torch.uint8, "fp8": torch.uint8,
} }
...@@ -266,7 +266,7 @@ def get_nvcc_cuda_version() -> Optional[Version]: ...@@ -266,7 +266,7 @@ def get_nvcc_cuda_version() -> Optional[Version]:
return nvcc_cuda_version return nvcc_cuda_version
def _generate_random_fp8_e5m2( def _generate_random_fp8(
tensor: torch.tensor, tensor: torch.tensor,
low: float, low: float,
high: float, high: float,
...@@ -282,7 +282,7 @@ def _generate_random_fp8_e5m2( ...@@ -282,7 +282,7 @@ def _generate_random_fp8_e5m2(
from vllm._C import cache_ops from vllm._C import cache_ops
tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
tensor_tmp.uniform_(low, high) tensor_tmp.uniform_(low, high)
cache_ops.convert_fp8_e5m2(tensor_tmp, tensor) cache_ops.convert_fp8(tensor_tmp, tensor)
del tensor_tmp del tensor_tmp
...@@ -311,7 +311,7 @@ def create_kv_caches_with_random( ...@@ -311,7 +311,7 @@ def create_kv_caches_with_random(
raise ValueError(f"Invalid model dtype: {model_dtype}") raise ValueError(f"Invalid model dtype: {model_dtype}")
elif cache_dtype in ["half", "bfloat16", "float"]: elif cache_dtype in ["half", "bfloat16", "float"]:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
elif cache_dtype == "fp8_e5m2": elif cache_dtype == "fp8":
torch_dtype = torch.uint8 torch_dtype = torch.uint8
else: else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
...@@ -328,10 +328,10 @@ def create_kv_caches_with_random( ...@@ -328,10 +328,10 @@ def create_kv_caches_with_random(
key_cache = torch.empty(size=key_cache_shape, key_cache = torch.empty(size=key_cache_shape,
dtype=torch_dtype, dtype=torch_dtype,
device=device) device=device)
if cache_dtype == 'fp8_e5m2': if cache_dtype in ["auto", "half", "bfloat16", "float"]:
_generate_random_fp8_e5m2(key_cache, -scale, scale)
elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
key_cache.uniform_(-scale, scale) key_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8':
_generate_random_fp8(key_cache, -scale, scale)
else: else:
raise ValueError( raise ValueError(
f"Does not support key cache of type {cache_dtype}") f"Does not support key cache of type {cache_dtype}")
...@@ -343,10 +343,10 @@ def create_kv_caches_with_random( ...@@ -343,10 +343,10 @@ def create_kv_caches_with_random(
value_cache = torch.empty(size=value_cache_shape, value_cache = torch.empty(size=value_cache_shape,
dtype=torch_dtype, dtype=torch_dtype,
device=device) device=device)
if cache_dtype == 'fp8_e5m2': if cache_dtype in ["auto", "half", "bfloat16", "float"]:
_generate_random_fp8_e5m2(value_cache, -scale, scale)
elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
value_cache.uniform_(-scale, scale) value_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8':
_generate_random_fp8(value_cache, -scale, scale)
else: else:
raise ValueError( raise ValueError(
f"Does not support value cache of type {cache_dtype}") f"Does not support value cache of type {cache_dtype}")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment