"src/vscode:/vscode.git/clone" did not exist on "44a85546e3608e170f466dec15671bdf53fc3a88"
Unverified Commit 0bb0f763 authored by bjmsong, committed by GitHub

Support FP8 E4M3 KV Cache (#2786)


Co-authored-by: root <bjmsong@126.com>
parent 85b2e057
@@ -353,7 +353,9 @@ class FlashInferAttnBackend(AttentionBackend):
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -362,6 +364,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -387,7 +391,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)

             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -412,13 +418,17 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -47,6 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
+        self.k_scale = 1.0
+        self.v_scale = 1.0

     def forward(
         self,
@@ -109,8 +109,8 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype == torch.float8_e5m2:
-            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
@@ -256,11 +256,13 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = cache_k.to(self.dtype)
-            cache_v = cache_v.to(self.dtype)
+            cache_k = (cache_k / k_scale).to(self.dtype)
+            cache_v = (cache_v / v_scale).to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
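Why the divide-by-scale happens on the write path: FP8 E4M3 can only represent magnitudes up to roughly 448, so set_kv_buffer divides K and V by their per-layer scales before the cast, and the attention kernel multiplies the scales back in through the k_scale/v_scale arguments shown above. A minimal standalone sketch of that round trip (not part of the diff; values are illustrative and a PyTorch build with the float8 dtypes is assumed):

import torch

k = torch.randn(4, 8) * 500.0             # hypothetical K values that exceed E4M3's ~448 max
k_scale = (k.abs().max() / 448.0).item()  # illustrative scale; real scales come from the calibration JSON

stored = (k / k_scale).to(torch.float8_e4m3fn)  # what set_kv_buffer now does on write
recovered = stored.to(torch.float32) * k_scale  # what the kernel undoes at read time via k_scale
print((k - recovered).abs().max())              # quantization error stays bounded instead of values being clipped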
@@ -54,6 +54,7 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
+    is_cuda,
     is_hip,
     monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_p2p_access_check,
@@ -277,6 +278,29 @@ class ModelRunner:
             device_config=DeviceConfig(self.device),
         )

+        if self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if self.server_args.quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(
+                        self.server_args.quantization_param_path
+                    )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
+                else:
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling factors provided but "
+                        "model %s does not support loading scaling factors."
+                        % self.model.__class__
+                    )
+            else:
+                logger.warning(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!"
+                )
+
         # Parse other args
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
@@ -516,6 +540,9 @@ class ModelRunner:
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
+        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if is_cuda():
+                self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
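A side note on why the scaling factors matter more for E4M3 than for the already supported E5M2 (not part of the diff; any PyTorch build with the float8 dtypes reports the same limits): E4M3 spends an extra bit on the mantissa, so it is more precise but has a much narrower dynamic range and relies on per-layer scales to keep keys and values representable.

import torch

# E4M3: 3 mantissa bits, max ~448   -> finer precision, narrow range.
# E5M2: 2 mantissa bits, max ~57344 -> coarser precision, wide range.
print(torch.finfo(torch.float8_e4m3fn).max)
print(torch.finfo(torch.float8_e5m2).max)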
@@ -22,8 +22,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -299,6 +303,30 @@ class LlamaModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+                if hasattr(layer_self_attn.attn, "k_scale"):
+                    layer_self_attn.attn.k_scale = scaling_factor
+                    layer_self_attn.attn.v_scale = scaling_factor
+                else:
+                    raise RuntimeError(
+                        "Self attention has no KV cache scaling factor attribute!"
+                    )
+

 class LlamaForCausalLM(nn.Module):
@@ -534,6 +562,9 @@ class LlamaForCausalLM(nn.Module):
             torch.cuda.empty_cache()
             torch.cuda.synchronize()

+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+

 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
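For reference, kv_cache_scales_loader comes from vLLM and yields one (layer index, scaling factor) pair per decoder layer for the current tensor-parallel rank, which load_kv_cache_scales above writes onto each RadixAttention's k_scale/v_scale. A hedged sketch of exercising it directly against the fixture shown near the end of this page (assumes vLLM is installed; the path and argument values are illustrative and mirror the call above):

from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader

# Arguments in the same order as the call in LlamaModel.load_kv_cache_scales:
# path, tp_rank, tp_size, num_hidden_layers, model_type.
for layer_idx, scaling_factor in kv_cache_scales_loader(
    "kv_cache_scales_llama3_8b_chat.json", 0, 1, 32, "llama"
):
    print(layer_idx, scaling_factor)  # with the fixture below, every factor is 1.0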
@@ -32,6 +32,7 @@ from sglang.srt.utils import (
     is_hip,
     is_ipv6,
     is_port_available,
+    nullable_str,
 )

 logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
+    quantization_param_path: nullable_str = None
     quantization: Optional[str] = None
     context_length: Optional[int] = None
     device: str = "cuda"
@@ -350,8 +352,17 @@ class ServerArgs:
             "--kv-cache-dtype",
             type=str,
             default=ServerArgs.kv_cache_dtype,
-            choices=["auto", "fp8_e5m2"],
-            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.',
+        )
+        parser.add_argument(
+            "--quantization-param-path",
+            type=nullable_str,
+            default=None,
+            help="Path to the JSON file containing the KV cache "
+            "scaling factors. This should generally be supplied when "
+            "the KV cache dtype is FP8. Otherwise, the KV cache scaling "
+            "factors default to 1.0, which may cause accuracy issues.",
+        )
         parser.add_argument(
             "--quantization",
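A hedged usage sketch of the two new options together (the model path is illustrative; the regression test at the end of this page passes the same flags through popen_launch_server):

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model
    kv_cache_dtype="fp8_e4m3",
    quantization_param_path="kv_cache_scales_llama3_8b_chat.json",
)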
@@ -1375,3 +1375,9 @@ def debug_timing(func):
             return func(*args, **kwargs)

     return wrapper
+
+
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val
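The new test fixture below, kv_cache_scales_llama3_8b_chat.json (referenced by the test that follows), carries one scaling factor per decoder layer, keyed first by tensor-parallel rank and then by layer index. Every factor is 1.0 here, so the test exercises the loading path without changing the numerics.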
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 1,
"1": 1,
"2": 1,
"3": 1,
"4": 1,
"5": 1,
"6": 1,
"7": 1,
"8": 1,
"9": 1,
"10": 1,
"11": 1,
"12": 1,
"13": 1,
"14": 1,
"15": 1,
"16": 1,
"17": 1,
"18": 1,
"19": 1,
"20": 1,
"21": 1,
"22": 1,
"23": 1,
"24": 1,
"25": 1,
"26": 1,
"27": 1,
"28": 1,
"29": 1,
"30": 1,
"31": 1
}
}
}
}
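The new regression test below launches a server with the FP8 E4M3 KV cache and the fixture above, then checks that MGSM-EN and MMLU accuracy stay above fixed thresholds.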
import os
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestFp8Kvcache(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        dirpath = os.path.dirname(__file__)
        config_file = os.path.join(dirpath, "kv_cache_scales_llama3_8b_chat.json")
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--kv-cache-dtype",
                "fp8_e4m3",
                "--quantization-param-path",
                config_file,
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mgsm_en(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.835)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)


if __name__ == "__main__":
    unittest.main()