Unverified commit 2358c2bb authored by Daniël de Kok, committed by GitHub

Add basic FP8 KV cache support (#2603)

* Add basic FP8 KV cache support

This change adds rudimentary FP8 KV cache support. It is enabled by
passing `--kv-cache-dtype fp8_e5m2` to the launcher, which stores the
KV cache in that type (see the interface sketch below). However,
support is still limited:

* Only the `fp8_e5m2` type is supported.
* The KV cache layout is the same as `float16`/`bfloat16` (HND).
* The FP8 KV cache is only supported for FlashInfer.
* Loading of scales is not yet supported.

* Fix Cargo.toml
parent 68103079
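The model diffs below all make the same change: the free function `reshape_and_cache(...)` and the tuple accesses `kv_cache[0]` / `kv_cache[1]` are replaced by a `KVCache` object with a `store()` method and `key` / `value` properties. What follows is only a minimal sketch of the interface those call sites assume, not the implementation added by this PR; the block size, the single cache layout, and the slot-scatter mechanics are illustrative assumptions.

```python
import torch

BLOCK_SIZE = 16  # illustrative; the real value depends on the attention backend


class KVCache:
    """Sketch of the per-layer paged KV cache that the model code expects."""

    def __init__(
        self,
        *,
        num_blocks: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        # One paged buffer each for keys and values. The real layout differs
        # per attention backend; a single
        # [num_blocks, BLOCK_SIZE, num_heads, head_size] layout keeps the
        # sketch short.
        shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
        self.kv_cache = (
            torch.zeros(shape, dtype=dtype, device=device),
            torch.zeros(shape, dtype=dtype, device=device),
        )

    @property
    def key(self) -> torch.Tensor:
        return self.kv_cache[0]

    @property
    def value(self) -> torch.Tensor:
        return self.kv_cache[1]

    def store(self, *, key: torch.Tensor, value: torch.Tensor, slots: torch.Tensor):
        # Scatter the new per-token key/value vectors into their flat slots,
        # casting to the cache dtype (e.g. fp8_e5m2) when it differs from the
        # model dtype.
        key_cache = self.key.view(-1, *self.key.shape[-2:])
        value_cache = self.value.view(-1, *self.value.shape[-2:])
        key_cache[slots] = key.to(key_cache.dtype)
        value_cache[slots] = value.to(value_cache.dtype)
```

For example, with `num_blocks=4, num_heads=8, head_size=64, dtype=torch.float16`, calling `store(key=k, value=v, slots=torch.tensor([0, 1]))` writes two tokens into the first block; with `--kv-cache-dtype fp8_e5m2` the same call would additionally down-cast to `torch.float8_e5m2`.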
@@ -28,7 +28,6 @@ from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
@@ -224,15 +223,15 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
kv_cache.store(key=key, value=value, slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
kv_cache.key if PREFILL_IN_KV_CACHE else key,
kv_cache.value if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
@@ -241,8 +240,8 @@ class FlashGPT2Attention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
......
@@ -28,7 +28,6 @@ from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
@@ -186,15 +185,15 @@ class FlashGPTJAttention(torch.nn.Module):
else:
self.rotary_emb(query, key, cos, sin)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
kv_cache.store(key=key, value=value, slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
kv_cache.key if PREFILL_IN_KV_CACHE else key,
kv_cache.value if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
@@ -203,8 +202,8 @@ class FlashGPTJAttention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
......
@@ -27,13 +27,12 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE, KVCache
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
@@ -202,7 +201,7 @@ class FlashLlamaAttention(torch.nn.Module):
cos,
sin,
cu_seqlen_prefill,
kv_cache,
kv_cache: KVCache,
block_tables,
slots,
seqlen,
@@ -222,15 +221,15 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@@ -239,8 +238,8 @@ class FlashLlamaAttention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
......
@@ -30,7 +30,6 @@ from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
@@ -210,17 +209,15 @@ class MistralAttention(torch.nn.Module):
else:
kv_to_cache = kv
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@@ -230,8 +227,8 @@ class MistralAttention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
......
@@ -37,7 +37,6 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
reshape_and_cache,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastRMSNorm
@@ -258,17 +257,15 @@ class MixtralAttention(torch.nn.Module):
else:
kv_to_cache = kv
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@@ -278,8 +275,8 @@ class MixtralAttention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
......
@@ -29,7 +29,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
@@ -165,15 +164,15 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
qkv[:, 0],
kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
kv_cache.key if PREFILL_IN_KV_CACHE else qkv[:, 1],
kv_cache.value if PREFILL_IN_KV_CACHE else qkv[:, 2],
seqlen,
block_tables,
self.softmax_scale,
@@ -182,8 +181,8 @@ class FlashNeoxAttention(torch.nn.Module):
else:
attn_output = paged_attention(
qkv[:, 0],
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
......
@@ -9,7 +9,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
@@ -188,14 +187,14 @@ class FlashPhiAttention(torch.nn.Module):
)
# Reshape key and value and cache
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@@ -204,8 +203,8 @@ class FlashPhiAttention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
......
@@ -8,7 +8,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
@@ -128,17 +127,15 @@ class Qwen2Attention(torch.nn.Module):
else:
kv_to_cache = kv
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@@ -148,8 +145,8 @@ class Qwen2Attention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
......
@@ -18,7 +18,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
Seqlen,
)
@@ -200,15 +199,15 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@@ -217,8 +216,8 @@ class FlashRWAttention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
@@ -312,12 +311,8 @@ class FlashRWLargeAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
reshape_and_cache(
kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(),
kv_cache[0],
kv_cache[1],
slots,
kv_cache.store(
key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots
)
# Prefill
@@ -325,8 +320,8 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
seqlen,
block_tables,
self.softmax_scale,
@@ -335,8 +330,8 @@ class FlashRWLargeAttention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
......
@@ -8,7 +8,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
@@ -284,17 +283,15 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)
kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
kv_cache.key if PREFILL_IN_KV_CACHE else key_value[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else key_value[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@@ -303,8 +300,8 @@ class FlashMQAttention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
......
@@ -29,7 +29,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
@@ -233,17 +232,15 @@ class Starcoder2Attention(torch.nn.Module):
else:
kv_to_cache = kv
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@@ -253,8 +250,8 @@ class Starcoder2Attention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
......
@@ -46,7 +46,7 @@ from text_generation_server.models.globals import (
TGI_WIGGLE_ROOM,
get_adapter_to_index,
)
from text_generation_server.layers.attention import Seqlen
from text_generation_server.layers.attention import KVCache, Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader
@@ -937,6 +937,7 @@ class FlashCausalLM(Model):
# Deepseek V2 uses different QK and V dims.
head_size: Optional[int] = None,
skip_special_tokens: bool = True,
kv_cache_dtype: Optional[torch.dtype] = None,
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
@@ -1034,6 +1035,7 @@ class FlashCausalLM(Model):
self.cuda_graphs = {}
self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
@@ -1083,61 +1085,16 @@
):
self.kv_cache = []
empty_cache()
element_size = torch.tensor([], dtype=dtype).element_size()
if SYSTEM == "ipex" and device.type == "xpu":
x = 1
else:
x = BLOCK_SIZE // element_size
if ATTENTION in {"flashdecoding", "flashinfer"}:
self.kv_cache = [
(
torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
elif SYSTEM == "ipex" and device == torch.device("cpu"):
self.kv_cache = [
(
torch.empty(
(num_blocks, num_heads, BLOCK_SIZE, head_size),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, BLOCK_SIZE, head_size),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
else:
self.kv_cache = [
(
torch.zeros(
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
dtype=dtype,
device=device,
),
torch.zeros(
(num_blocks, num_heads, head_size, BLOCK_SIZE),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
self.kv_cache = [
KVCache(
num_blocks=num_blocks,
num_heads=num_heads,
head_size=head_size,
dtype=dtype,
device=device,
)
for _ in range(num_layers)
]
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
@@ -1258,7 +1215,7 @@ class FlashCausalLM(Model):
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.kv_cache_dtype,
self.device,
)
max_bt = batch.max_blocks
@@ -1277,7 +1234,7 @@ class FlashCausalLM(Model):
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.dtype).element_size()
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
@@ -1291,6 +1248,8 @@ class FlashCausalLM(Model):
+ batch_num_blocks
)
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
del batch
self.init_kv_cache(
@@ -1298,7 +1257,7 @@ class FlashCausalLM(Model):
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.kv_cache_dtype,
self.device,
)
......
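The hunks above switch both the cache allocation and the free-memory accounting from `self.dtype` to `self.kv_cache_dtype`, which is what actually shrinks the cache when FP8 is enabled. Below is a quick back-of-the-envelope check using the same formula as the code; the model shape and block size are illustrative, and running it requires a PyTorch build that ships the `float8_e5m2` dtype.

```python
import torch

# Illustrative shape, not tied to any particular checkpoint.
num_layers, num_kv_heads, head_size, BLOCK_SIZE = 32, 8, 128, 16

for dtype in (torch.float16, torch.float8_e5m2):
    # Same accounting as the diff above: bytes needed per KV-cache block.
    dtype_size = torch.tensor([], dtype=dtype).element_size()  # 2 for fp16, 1 for fp8
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size   # elements per block (K or V)
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size
    print(f"{dtype}: {total_cache_size / 1024:.0f} KiB per block")

# fp16 -> 2048 KiB per block, fp8_e5m2 -> 1024 KiB per block: the same amount
# of free memory fits roughly twice as many blocks.
```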
@@ -205,6 +205,7 @@ def serve(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[str],
kv_cache_dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
max_input_tokens: int,
@@ -217,6 +218,7 @@ def serve(
quantize: Optional[str] = None,
speculate: Optional[int] = None,
dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
@@ -240,6 +242,7 @@ def serve(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
max_input_tokens,
adapter_to_index,
@@ -286,6 +289,7 @@ def serve(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
)
)
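`serve()` receives `kv_cache_dtype` as an `Optional[str]` from the CLI, while `FlashCausalLM.__init__` above takes an `Optional[torch.dtype]`, so the string has to be resolved somewhere in between. A hedged sketch of that mapping follows; the helper name and its location are assumptions, not code from this PR. Per the commit message only `fp8_e5m2` is accepted, and `None` falls back to the model dtype, matching `self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype`.

```python
from typing import Optional

import torch


def resolve_kv_cache_dtype(
    kv_cache_dtype: Optional[str], model_dtype: torch.dtype
) -> torch.dtype:
    """Hypothetical helper: map the --kv-cache-dtype string to a torch dtype."""
    if kv_cache_dtype is None:
        # Default behaviour: the KV cache uses the model dtype (fp16/bf16).
        return model_dtype
    if kv_cache_dtype == "fp8_e5m2":
        # Only e5m2 is supported by this change; scale loading is not yet handled.
        return torch.float8_e5m2
    raise ValueError(f"Unsupported KV cache dtype: {kv_cache_dtype}")


# Example: resolve_kv_cache_dtype("fp8_e5m2", torch.float16) -> torch.float8_e5m2
```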