Unverified Commit 2358c2bb authored by Daniël de Kok, committed by GitHub

Add basic FP8 KV cache support (#2603)

* Add basic FP8 KV cache support

This change adds rudimentary FP8 KV cache support. It is enabled by passing
`--kv-cache-dtype fp8_e5m2` to the launcher, which makes the server allocate
the KV cache with that type. However, support is still limited:

* Only the `fp8_e5m2` type is supported.
* The KV cache layout is the same as `float16`/`bfloat16` (HND).
* The FP8 KV cache is only supported for FlashInfer.
* Loading of scales is not yet supported.

* Fix Cargo.toml
parent 68103079
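The launcher flag itself is not part of the hunks shown below, which only cover the Python server. For orientation, here is a minimal, hypothetical sketch of how the `fp8_e5m2` string could be resolved to a torch dtype before it reaches the model; the helper name and error handling are assumptions, not code from this commit.

# Hypothetical helper (not from this commit): map the --kv-cache-dtype string to a
# torch dtype. Only fp8_e5m2 is accepted, matching the limitation listed above.
from typing import Optional

import torch


def resolve_kv_cache_dtype(kv_cache_dtype: Optional[str]) -> Optional[torch.dtype]:
    """None keeps the model dtype; "fp8_e5m2" selects the FP8 cache type."""
    if kv_cache_dtype is None:
        return None
    if kv_cache_dtype == "fp8_e5m2":
        # torch.float8_e5m2 requires PyTorch 2.1 or newer.
        return torch.float8_e5m2
    raise ValueError(f"Unsupported KV cache dtype: {kv_cache_dtype}")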
@@ -28,7 +28,6 @@ from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -224,15 +223,15 @@ class FlashGPT2Attention(torch.nn.Module):
         key = key.view(-1, self.num_heads, self.head_size)
         value = value.view(-1, self.num_heads, self.head_size)

-        reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=key, value=value, slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
-                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
+                kv_cache.key if PREFILL_IN_KV_CACHE else key,
+                kv_cache.value if PREFILL_IN_KV_CACHE else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -241,8 +240,8 @@ class FlashGPT2Attention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
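The same mechanical change repeats in every model below: `reshape_and_cache(...)` plus `kv_cache[0]`/`kv_cache[1]` becomes `kv_cache.store(...)` plus `kv_cache.key`/`kv_cache.value`. Here is a rough, standalone sketch of the interface these call sites assume; the real `KVCache` in `text_generation_server.layers.attention` also handles the IPEX and paged layouts and dispatches to the proper cache-write kernels, none of which is shown here.

# Rough sketch of the KVCache surface the call sites above assume (not the real class).
import torch

BLOCK_SIZE = 16  # assumed paged-attention block size


class KVCache:
    """Per-layer key/value cache with a flashdecoding/flashinfer-style layout."""

    def __init__(self, *, num_blocks, num_heads, head_size, dtype, device):
        # Layout: [blocks, block_size, heads, head_dim], same for fp16/bf16 and fp8.
        shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
        self.kv_cache = (
            torch.zeros(shape, dtype=dtype, device=device),
            torch.zeros(shape, dtype=dtype, device=device),
        )

    @property
    def key(self) -> torch.Tensor:
        return self.kv_cache[0]

    @property
    def value(self) -> torch.Tensor:
        return self.kv_cache[1]

    def store(self, *, key: torch.Tensor, value: torch.Tensor, slots: torch.Tensor):
        # Scatter each new token's K/V (shape [tokens, heads, head_dim]) into its slot.
        # The cast covers the FP8 case, where the model runs in fp16/bf16 but the cache
        # dtype is float8_e5m2 (needs a PyTorch build with float8 support).
        k = self.key.view(-1, *self.key.shape[-2:])
        v = self.value.view(-1, *self.value.shape[-2:])
        k[slots] = key.to(k.dtype)
        v[slots] = value.to(v.dtype)

Keeping the FP8 layout identical to the `float16`/`bfloat16` one presumably lets FlashInfer consume the cache directly, which would match the commit's note that FP8 is only wired up for that backend.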
@@ -28,7 +28,6 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -186,15 +185,15 @@ class FlashGPTJAttention(torch.nn.Module):
         else:
             self.rotary_emb(query, key, cos, sin)

-        reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=key, value=value, slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
-                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
+                kv_cache.key if PREFILL_IN_KV_CACHE else key,
+                kv_cache.value if PREFILL_IN_KV_CACHE else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -203,8 +202,8 @@ class FlashGPTJAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
@@ -27,13 +27,12 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE, KVCache
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -202,7 +201,7 @@ class FlashLlamaAttention(torch.nn.Module):
         cos,
         sin,
         cu_seqlen_prefill,
-        kv_cache,
+        kv_cache: KVCache,
         block_tables,
         slots,
         seqlen,
@@ -222,15 +221,15 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -239,8 +238,8 @@ class FlashLlamaAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
@@ -30,7 +30,6 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -210,17 +209,15 @@ class MistralAttention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        reshape_and_cache(
-            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -230,8 +227,8 @@ class MistralAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
@@ -37,7 +37,6 @@ from text_generation_server.layers.attention import (
     Seqlen,
     attention,
     paged_attention,
-    reshape_and_cache,
 )
 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastRMSNorm
@@ -258,17 +257,15 @@ class MixtralAttention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        reshape_and_cache(
-            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -278,8 +275,8 @@ class MixtralAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
@@ -29,7 +29,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -165,15 +164,15 @@ class FlashNeoxAttention(torch.nn.Module):
         qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
         qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)

-        reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
+                kv_cache.key if PREFILL_IN_KV_CACHE else qkv[:, 1],
+                kv_cache.value if PREFILL_IN_KV_CACHE else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -182,8 +181,8 @@ class FlashNeoxAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 qkv[:, 0],
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
@@ -9,7 +9,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -188,14 +187,14 @@ class FlashPhiAttention(torch.nn.Module):
         )

         # Reshape key and value and cache
-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -204,8 +203,8 @@ class FlashPhiAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
@@ -8,7 +8,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -128,17 +127,15 @@ class Qwen2Attention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        reshape_and_cache(
-            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -148,8 +145,8 @@ class Qwen2Attention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
@@ -18,7 +18,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.attention import (
     attention,
     paged_attention,
-    reshape_and_cache,
     Seqlen,
 )
@@ -200,15 +199,15 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -217,8 +216,8 @@ class FlashRWAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -312,12 +311,8 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)

-        reshape_and_cache(
-            kv[:, :, 0].contiguous(),
-            kv[:, :, 1].contiguous(),
-            kv_cache[0],
-            kv_cache[1],
-            slots,
+        kv_cache.store(
+            key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots
         )

         # Prefill
@@ -325,8 +320,8 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -335,8 +330,8 @@ class FlashRWLargeAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
@@ -8,7 +8,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -284,17 +283,15 @@ class FlashMQAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         key_value = key_value.view(-1, 2, 1, self.head_size)

-        reshape_and_cache(
-            key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else key_value[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -303,8 +300,8 @@ class FlashMQAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
@@ -29,7 +29,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -233,17 +232,15 @@ class Starcoder2Attention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        reshape_and_cache(
-            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -253,8 +250,8 @@ class Starcoder2Attention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
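Before the scheduler changes below, a small aside on what `fp8_e5m2` storage means for accuracy: with no scales loaded yet (see the limitations above), cached keys and values are presumably just cast down on store and cast back up when attention reads them, trading the two-bit e5m2 mantissa for halved cache memory. A tiny illustration (not commit code; requires a PyTorch build with float8 support):

# Illustrative only: round-trip a few fp16 values through an fp8_e5m2 "cache".
import torch

k = torch.randn(4, dtype=torch.float16)
k_fp8 = k.to(torch.float8_e5m2)    # roughly what storing into an FP8 KV cache does
k_back = k_fp8.to(torch.float16)   # roughly what attention reads back
print(k)
print(k_back)
print((k - k_back).abs().max())    # small but nonzero rounding error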
@@ -46,7 +46,7 @@ from text_generation_server.models.globals import (
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
-from text_generation_server.layers.attention import Seqlen
+from text_generation_server.layers.attention import KVCache, Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 from text_generation_server.utils.quantization import get_loader
@@ -937,6 +937,7 @@ class FlashCausalLM(Model):
         # Deepseek V2 uses different QK and V dims.
         head_size: Optional[int] = None,
         skip_special_tokens: bool = True,
+        kv_cache_dtype: Optional[torch.dtype] = None,
     ):
         self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
@@ -1034,6 +1035,7 @@ class FlashCausalLM(Model):
         self.cuda_graphs = {}
         self.kv_cache = []
+        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype

         if ATTENTION == "flashinfer":
             from text_generation_server.layers.attention.flashinfer import (
@@ -1083,61 +1085,16 @@ class FlashCausalLM(Model):
     ):
         self.kv_cache = []
         empty_cache()
-
-        element_size = torch.tensor([], dtype=dtype).element_size()
-        if SYSTEM == "ipex" and device.type == "xpu":
-            x = 1
-        else:
-            x = BLOCK_SIZE // element_size
-
-        if ATTENTION in {"flashdecoding", "flashinfer"}:
-            self.kv_cache = [
-                (
-                    torch.empty(
-                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                )
-                for _ in range(num_layers)
-            ]
-        elif SYSTEM == "ipex" and device == torch.device("cpu"):
-            self.kv_cache = [
-                (
-                    torch.empty(
-                        (num_blocks, num_heads, BLOCK_SIZE, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (num_blocks, num_heads, BLOCK_SIZE, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                )
-                for _ in range(num_layers)
-            ]
-        else:
-            self.kv_cache = [
-                (
-                    torch.zeros(
-                        (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.zeros(
-                        (num_blocks, num_heads, head_size, BLOCK_SIZE),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                )
-                for _ in range(num_layers)
-            ]
+        self.kv_cache = [
+            KVCache(
+                num_blocks=num_blocks,
+                num_heads=num_heads,
+                head_size=head_size,
+                dtype=dtype,
+                device=device,
+            )
+            for _ in range(num_layers)
+        ]

     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
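The three hand-rolled layouts deleted above (flashdecoding/flashinfer, IPEX on CPU, and the paged layout with its packing factor `x`) presumably move into the `KVCache` constructor. For reference, here is the shape selection from the deleted branches condensed into one helper; this is a sketch, and `BLOCK_SIZE = 16` is an assumption standing in for the TGI global.

# Sketch: the per-layer cache shapes the deleted branches allocated, in one place.
import torch

BLOCK_SIZE = 16  # assumed paged-attention block size


def kv_cache_shapes(attention, system, device, num_blocks, num_heads, head_size, dtype):
    """Return (key_shape, value_shape) for one layer, mirroring the deleted branches."""
    if attention in {"flashdecoding", "flashinfer"}:
        shape = (num_blocks, BLOCK_SIZE, num_heads, head_size)
        return shape, shape
    if system == "ipex" and device == torch.device("cpu"):
        shape = (num_blocks, num_heads, BLOCK_SIZE, head_size)
        return shape, shape
    # Paged layout: the key cache packs x consecutive elements of the head
    # dimension together, with x derived from the element size as before.
    element_size = torch.tensor([], dtype=dtype).element_size()
    x = 1 if system == "ipex" and device.type == "xpu" else BLOCK_SIZE // element_size
    return (
        (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
        (num_blocks, num_heads, head_size, BLOCK_SIZE),
    )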
@@ -1258,7 +1215,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.dtype,
+            self.kv_cache_dtype,
             self.device,
         )
         max_bt = batch.max_blocks
@@ -1277,7 +1234,7 @@ class FlashCausalLM(Model):
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
-        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
+        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
@@ -1291,6 +1248,8 @@ class FlashCausalLM(Model):
             + batch_num_blocks
         )

+        log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
+
         del batch

         self.init_kv_cache(
@@ -1298,7 +1257,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.dtype,
+            self.kv_cache_dtype,
             self.device,
         )
...
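Since `dtype_size` now comes from `self.kv_cache_dtype`, an FP8 cache halves the bytes per block and therefore roughly doubles the number of blocks (cacheable tokens) that fit in the same free-memory budget. A worked example of the formula above, with illustrative Llama-7B-like shapes rather than numbers from this commit:

# Bytes per KV-cache block across all layers, float16 vs. float8_e5m2 (illustrative shapes).
import torch

BLOCK_SIZE = 16
num_layers, num_kv_heads, head_size = 32, 32, 128

for dtype in (torch.float16, torch.float8_e5m2):
    dtype_size = torch.tensor([], dtype=dtype).element_size()            # 2 bytes vs. 1 byte
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size             # elements per block (K or V), one layer
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size    # bytes per block, K and V, all layers
    print(f"{dtype}: {total_cache_size / 1024 / 1024:.2f} MiB per block")
# float16 -> 8.00 MiB per block, float8_e5m2 -> 4.00 MiB per block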
@@ -205,6 +205,7 @@ def serve(
     quantize: Optional[str],
     speculate: Optional[int],
    dtype: Optional[str],
+    kv_cache_dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
     max_input_tokens: int,
@@ -217,6 +218,7 @@ def serve(
         quantize: Optional[str] = None,
         speculate: Optional[int] = None,
         dtype: Optional[str] = None,
+        kv_cache_dtype: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         unix_socket_template = "unix://{}-{}"
@@ -240,6 +242,7 @@ def serve(
             quantize,
             speculate,
             dtype,
+            kv_cache_dtype,
             trust_remote_code,
             max_input_tokens,
             adapter_to_index,
@@ -286,6 +289,7 @@ def serve(
             quantize,
             speculate,
             dtype,
+            kv_cache_dtype,
             trust_remote_code,
         )
     )