"tests/python/common/transforms/test_transform.py" did not exist on "7c3e1f94b9c897aad8cd50fac4dca35a4954d184"
Unverified Commit 2c615d12 authored by Ke Bao, committed by GitHub

[Feature] Support fp8 e5m2 kv cache with flashinfer (#1204)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 61bb223e
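
Overview of the diff below: a new --kv-cache-dtype server argument (choices "auto" and "fp8_e5m2") lets the KV cache be stored in FP8 E5M2 when the flashinfer backend is used. ModelRunner resolves the flag to torch.float8_e5m2 (falling back to the model dtype for the Triton backend or MLA models), the token-to-KV pools gain an abstract get_*_buffer / set_kv_buffer interface that converts entries to the cache dtype on write, and the flashinfer wrappers are told the KV data type and the query data type separately.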
@@ -203,7 +203,6 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        k_cache[input_metadata.out_cache_loc] = cache_k
-        v_cache[input_metadata.out_cache_loc] = cache_v
+        input_metadata.token_to_kv_pool.set_kv_buffer(
+            self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
+        )
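
With this change, the per-layer scatter into the key/value buffers moves out of RadixAttention and into the pool's set_kv_buffer, so dtype handling (including the fp8 conversion introduced below) happens in one place instead of in every attention layer.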
@@ -16,7 +16,8 @@ limitations under the License.
 """Memory pool."""
 
 import logging
-from typing import List, Union
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Union
 
 import torch
@@ -52,14 +53,21 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
-class BaseTokenToKVPool:
+class BaseTokenToKVPool(ABC):
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
         self,
         size: int,
+        dtype: torch.dtype,
     ):
         self.size = size
+        self.dtype = dtype
+        if dtype == torch.float8_e5m2:
+            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
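
The NOTE above is the core trick: fp8 values are kept in a uint8-backed buffer and reinterpreted on access. A minimal standalone sketch of that reinterpretation, assuming PyTorch 2.1+ (for torch.float8_e5m2) and a CUDA device; the tensor names are illustrative only, not code from this commit:

import torch

# Illustrative buffer: the fp8 payload is stored as uint8 so index_put works.
store = torch.empty(8, 4, dtype=torch.uint8, device="cuda")

x = torch.randn(2, 4, device="cuda", dtype=torch.float16).to(torch.float8_e5m2)

# Writes go through a uint8 view of the fp8 values.
store[torch.tensor([0, 3], device="cuda")] = x.view(torch.uint8)

# Reads reinterpret the same bytes back as float8_e5m2 (zero-copy view).
fp8_view = store.view(torch.float8_e5m2)
print(fp8_view[0].to(torch.float16))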
@@ -112,6 +120,28 @@ class BaseTokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
 
+    @abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
 
 class MHATokenToKVPool(BaseTokenToKVPool):
@@ -123,26 +153,52 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)
 
         # [size, head_num, head_dim] for each layer
         self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]
 
     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
 
     def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if cache_v.dtype != self.dtype:
+            cache_v = cache_v.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
+        else:
+            self.k_buffer[layer_id][loc] = cache_k
+            self.v_buffer[layer_id][loc] = cache_v
 
 
 class MLATokenToKVPool(BaseTokenToKVPool):
@@ -155,23 +211,41 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)
 
         self.kv_lora_rank = kv_lora_rank
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=dtype,
+                dtype=self.store_dtype,
                 device="cuda",
             )
             for _ in range(layer_num)
         ]
 
     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id].view(self.dtype)
         return self.kv_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
         return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
 
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+        else:
+            self.kv_buffer[layer_id][loc] = cache_k
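
Note that MLATokenToKVPool keeps a single combined latent buffer (kv_lora_rank + qk_rope_head_dim), so its set_kv_buffer writes only cache_k; cache_v is accepted purely to satisfy the shared interface. Below is a hedged end-to-end sketch of the MHA pool with an fp8 cache; the import path and the toy sizes are assumptions, and it needs a CUDA device plus PyTorch 2.1+:

import torch

from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool  # path assumed

# Toy pool: 16 token slots, 2 KV heads, head_dim 8, 1 layer, fp8 e5m2 cache.
pool = MHATokenToKVPool(
    16, dtype=torch.float8_e5m2, head_num=2, head_dim=8, layer_num=1
)

loc = torch.arange(4, device="cuda")  # slots to write
cache_k = torch.randn(4, 2, 8, device="cuda", dtype=torch.float16)
cache_v = torch.randn(4, 2, 8, device="cuda", dtype=torch.float16)

# Converts to fp8 and scatters into the uint8-backed per-layer buffers.
pool.set_kv_buffer(0, loc, cache_k, cache_v)

# Returned buffers are zero-copy views that already carry the fp8 dtype.
k_buf, v_buf = pool.get_kv_buffer(0)
assert k_buf.dtype == torch.float8_e5m2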
@@ -315,6 +315,8 @@ def update_flashinfer_indices(
             num_kv_heads,
             head_dim,
             1,
+            data_type=model_runner.kv_cache_dtype,
+            q_data_type=model_runner.dtype,
         )
     else:
         # extend part
@@ -393,6 +395,8 @@ def update_flashinfer_indices(
             num_kv_heads,
             head_dim,
             1,
+            data_type=model_runner.kv_cache_dtype,
+            q_data_type=model_runner.dtype,
         )
     else:
         # extend part
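
Passing data_type (the KV-cache storage dtype, model_runner.kv_cache_dtype) and q_data_type (the query dtype, model_runner.dtype) separately lets flashinfer plan its decode kernels for an fp8 e5m2 paged KV cache while the queries remain in the model dtype (fp16/bf16); previously neither was passed explicitly.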
......
@@ -311,7 +311,7 @@ class ModelRunner:
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * self.model_config.num_hidden_layers
-                * torch._utils._element_size(self.dtype)
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
             cell_size = (
@@ -319,7 +319,7 @@ class ModelRunner:
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
-                * torch._utils._element_size(self.dtype)
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
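
The cell_size change above is where the memory win comes from: torch._utils._element_size reports 1 byte for fp8 e5m2 versus 2 for fp16/bf16, so each cached token costs half as much and roughly twice as many tokens fit in the same KV budget. A small worked sketch with hypothetical model dimensions (8 KV heads, head_dim 128, 32 layers):

import torch

# Hypothetical MHA config; mirrors the cell_size formula in ModelRunner.
num_kv_heads, head_dim, num_layers = 8, 128, 32

for kv_cache_dtype in (torch.float16, torch.float8_e5m2):
    cell_size = (
        num_kv_heads
        * head_dim
        * num_layers
        * 2  # K and V
        * torch._utils._element_size(kv_cache_dtype)
    )
    print(kv_cache_dtype, cell_size, "bytes per token")

# float16 -> 131072 bytes/token, float8_e5m2 -> 65536 bytes/token.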
@@ -333,6 +333,21 @@ class ModelRunner:
         max_num_reqs: int = None,
         max_total_tokens: int = None,
     ):
+        if self.server_args.kv_cache_dtype == "auto":
+            self.kv_cache_dtype = self.dtype
+        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
+            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
+                logger.warning(
+                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
+                )
+                self.kv_cache_dtype = self.dtype
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
+        else:
+            raise ValueError(
+                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
+            )
+
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
@@ -369,7 +384,7 @@ class ModelRunner:
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.dtype,
+                dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
@@ -380,7 +395,7 @@ class ModelRunner:
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.dtype,
+                dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
......
@@ -33,6 +33,7 @@ class ServerArgs:
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
+    kv_cache_dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
    quantization: Optional[str] = None
@@ -196,6 +197,13 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
+        parser.add_argument(
+            "--kv-cache-dtype",
+            type=str,
+            default=ServerArgs.kv_cache_dtype,
+            choices=["auto", "fp8_e5m2"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
......
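
For completeness, a hedged usage sketch of the new option (not part of this diff). The launch_server entry point and the ServerArgs import path are the usual sglang ones and are assumed here, not introduced by this commit:

# Command line (assumed entry point):
#   python -m sglang.launch_server --model-path <model> --kv-cache-dtype fp8_e5m2

# Programmatic equivalent via the dataclass field added above:
from sglang.srt.server_args import ServerArgs  # import path assumed

args = ServerArgs(model_path="<model>", kv_cache_dtype="fp8_e5m2")
# ModelRunner resolves this to torch.float8_e5m2 when flashinfer is enabled and
# falls back to the model dtype for the Triton backend or MLA models.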