Unverified commit eab07f74, authored by Daniël de Kok and committed by GitHub

Add support for FP8 KV cache scales (#2628)

* Add support for FP8 KV cache scales

Since FP8 only has limited dynamic range, we can scale keys/values
before storing them into the cache (and unscale them in attention). To
avoid rescaling the cache as the absmax values change, good scales are
usually determined per layer using calibration data and stored
in the checkpoint.
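
In a minimal sketch (the helper names below are illustrative, not the actual TGI API), scaling on store and unscaling on read looks like this:

```python
import torch


def scale_for_fp8_cache(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Divide by the calibrated per-layer scale so values fit the narrow
    # float8_e4m3fn range (about ±448), then cast for storage in the cache.
    return (x / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)


def unscale_from_fp8_cache(
    x: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype
) -> torch.Tensor:
    # Reverse the transformation when the attention kernel reads the cache.
    return x.to(dtype) * scale
```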

This change adds support for using key-value scales and loading them
from checkpoints in the two most common formats:

- Separate per-layer `k_scale` and `v_scale` scalars.
- Per-layer `kv_scale` scalar (older format).

Currently, scales are only used with a `float8_e4m3fn` cache.
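
A rough sketch of how such per-layer scales might be looked up is below; `load_kv_scales` and the `has_tensor` existence check are illustrative stand-ins, not necessarily the shape of the actual `get_kv_scales` helper:

```python
from dataclasses import dataclass

import torch


@dataclass
class KVScales:
    key_scale: torch.Tensor
    value_scale: torch.Tensor


def load_kv_scales(weights, prefix: str) -> KVScales:
    # Default to 1.0 (no rescaling) when the checkpoint carries no scales.
    key_scale = value_scale = torch.ones([], dtype=torch.float32)
    if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(f"{prefix}.v_scale"):
        # Newer format: separate per-layer key and value scales.
        key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
        value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
    elif weights.has_tensor(f"{prefix}.kv_scale"):
        # Older format: a single scale shared by keys and values.
        key_scale = value_scale = weights.get_tensor(
            f"{prefix}.kv_scale", to_dtype=False
        ).float()
    return KVScales(key_scale=key_scale, value_scale=value_scale)
```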

Besides adding support for key/value scales, the `fp8_quantize` function
is also extended to support quantization with a kernel vendored from
vLLM. This is slightly faster than the PyTorch implementation and also
computes the scale in FP32, potentially improving accuracy.
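
For reference, a pure-PyTorch sketch of dynamic per-tensor FP8 quantization with the scale kept in FP32 (the vendored vLLM kernel fuses these steps; this is not its implementation):

```python
import torch

FP8_MAX = 448.0  # maximum magnitude representable in float8_e4m3fn


def fp8_quantize_ref(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Compute the per-tensor scale in FP32 from the absolute maximum.
    scale = x.abs().max().to(torch.float32) / FP8_MAX
    scale = scale.clamp(min=1e-12)  # avoid division by zero for all-zero inputs
    # Quantize in FP32, clamp to the representable range, then cast to FP8.
    qx = (x.to(torch.float32) / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return qx, scale
```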

* Update FP8 KV cache test to use checkpoint with scales

* `can_scale`: check that the attention is flashinfer
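
A hedged sketch of what such a guard might look like; `ATTENTION_BACKEND` is a stand-in for however the active attention backend is reported, not the actual TGI flag:

```python
import torch

ATTENTION_BACKEND = "flashinfer"  # assumption: stand-in for the configured backend


def can_scale(cache_dtype: torch.dtype) -> bool:
    # Scales are only applied when the cache is FP8 *and* the flashinfer
    # backend, which knows how to unscale during attention, is in use.
    return cache_dtype == torch.float8_e4m3fn and ATTENTION_BACKEND == "flashinfer"
```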
parent 14a0df3a
@@ -36,6 +36,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales


 def load_qkv(config, prefix: str, weights, head_size, num_heads):
@@ -193,6 +194,7 @@ class FlashGPT2Attention(torch.nn.Module):
             head_size=self.head_size,
             num_heads=self.num_heads,
         )
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")

         self.o_proj = load_row(
             config,
@@ -222,7 +224,12 @@ class FlashGPT2Attention(torch.nn.Module):
         key = key.view(-1, self.num_heads, self.head_size)
         value = value.view(-1, self.num_heads, self.head_size)

-        kv_cache.store(key=key, value=value, slots=slots)
+        kv_cache.store(
+            key=key,
+            value=value,
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -232,6 +239,7 @@ class FlashGPT2Attention(torch.nn.Module):
                 key=key,
                 value=value,
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -246,6 +254,7 @@ class FlashGPT2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
...
@@ -24,6 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -138,6 +139,7 @@ class FlashGPTJAttention(torch.nn.Module):
             prefix=prefix,
             weights=weights,
         )
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")

         self.o_proj = load_row(
             config,
@@ -184,7 +186,12 @@ class FlashGPTJAttention(torch.nn.Module):
         else:
             self.rotary_emb(query, key, cos, sin)

-        kv_cache.store(key=key, value=value, slots=slots)
+        kv_cache.store(
+            key=key,
+            value=value,
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -194,6 +201,7 @@ class FlashGPTJAttention(torch.nn.Module):
                 key=key,
                 value=value,
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -208,6 +216,7 @@ class FlashGPTJAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
...
@@ -27,7 +27,10 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN

-from text_generation_server.layers.attention import KVCache
+from text_generation_server.layers.attention import (
+    KVCache,
+    get_kv_scales,
+)
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
@@ -179,6 +182,8 @@ class FlashLlamaAttention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights, index)
         self.index = index

+        self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
         o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
@@ -224,7 +229,12 @@ class FlashLlamaAttention(torch.nn.Module):
             self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
+        kv_cache.store(
+            key=kv[:, 0],
+            value=kv[:, 1],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -233,6 +243,7 @@ class FlashLlamaAttention(torch.nn.Module):
                 query=query,
                 key=kv[:, 0],
                 value=kv[:, 1],
+                kv_scales=self.kv_scales,
                 kv_cache=kv_cache,
                 seqlen=seqlen,
                 block_tables=block_tables,
@@ -248,6 +259,7 @@ class FlashLlamaAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )
         return self.o_proj(
...
@@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -158,6 +159,7 @@ class MistralAttention(torch.nn.Module):
             ],
             process_group=weights.process_group,
         )
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")

         o_proj = TensorParallelRowLinear.load(
             config,
@@ -208,7 +210,12 @@ class MistralAttention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
+        kv_cache.store(
+            key=kv_to_cache[:, 0],
+            value=kv_to_cache[:, 1],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -218,6 +225,7 @@ class MistralAttention(torch.nn.Module):
                 key=kv_to_cache[:, 0],
                 value=kv_to_cache[:, 1],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -233,6 +241,7 @@ class MistralAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )
         return self.o_proj(
...
@@ -38,6 +38,7 @@ from text_generation_server.layers.attention import (
     attention,
     paged_attention,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
@@ -213,6 +214,7 @@ class MixtralAttention(torch.nn.Module):
         )

         self.query_key_value = load_attention(config, prefix, weights)
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")

         self.o_proj = TensorParallelRowLinear.load(
             config,
@@ -256,7 +258,12 @@ class MixtralAttention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
+        kv_cache.store(
+            key=kv_to_cache[:, 0],
+            value=kv_to_cache[:, 1],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -266,6 +273,7 @@ class MixtralAttention(torch.nn.Module):
                 key=kv_to_cache[:, 0],
                 value=kv_to_cache[:, 1],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -281,6 +289,7 @@ class MixtralAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
...
@@ -38,6 +38,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
@@ -130,6 +131,7 @@ class FlashNeoxAttention(torch.nn.Module):
             head_size=self.head_size,
             hidden_size=self.hidden_size,
         )
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")
         self.dense = load_row(
             config, prefix=f"{prefix}.dense", weights=weights, bias=True
         )
@@ -163,7 +165,12 @@ class FlashNeoxAttention(torch.nn.Module):
             qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
             qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)

-        kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots)
+        kv_cache.store(
+            key=qkv[:, 1],
+            value=qkv[:, 2],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -173,6 +180,7 @@ class FlashNeoxAttention(torch.nn.Module):
                 key=qkv[:, 1],
                 value=qkv[:, 2],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -187,6 +195,7 @@ class FlashNeoxAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )
         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
...
@@ -18,6 +18,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
@@ -137,6 +138,7 @@ class FlashPhiAttention(torch.nn.Module):
         )

         self.query_key_value = load_attention(config, prefix, weights)
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")

         # in llama the dense layer is called "o_proj" and has bias=False
         self.dense = TensorParallelRowLinear.load(
@@ -186,7 +188,12 @@ class FlashPhiAttention(torch.nn.Module):
             )

         # Reshape key and value and cache
-        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
+        kv_cache.store(
+            key=kv[:, 0],
+            value=kv[:, 1],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -194,6 +201,7 @@ class FlashPhiAttention(torch.nn.Module):
                 query=query,
                 key=kv[:, 0],
                 value=kv[:, 1],
+                kv_scales=self.kv_scales,
                 kv_cache=kv_cache,
                 seqlen=seqlen,
                 block_tables=block_tables,
@@ -209,6 +217,7 @@ class FlashPhiAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )
         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
...
@@ -16,6 +16,7 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     SpeculativeHead,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -84,6 +85,8 @@ class Qwen2Attention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights)

+        self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
@@ -126,7 +129,12 @@ class Qwen2Attention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
+        kv_cache.store(
+            key=kv_to_cache[:, 0],
+            value=kv_to_cache[:, 1],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -136,6 +144,7 @@ class Qwen2Attention(torch.nn.Module):
                 key=kv_to_cache[:, 0],
                 value=kv_to_cache[:, 1],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -151,6 +160,7 @@ class Qwen2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
...
@@ -12,6 +12,7 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     get_linear,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.layernorm import FastLayerNorm
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.attention import (
@@ -158,6 +159,7 @@ class FlashRWAttention(torch.nn.Module):
             weights=weights,
             bias=config.bias,
         )
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")
         self.dense = load_row(
             config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
         )
@@ -198,7 +200,12 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
+        kv_cache.store(
+            key=kv[:, 0],
+            value=kv[:, 1],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -208,6 +215,7 @@ class FlashRWAttention(torch.nn.Module):
                 key=kv[:, 0],
                 value=kv[:, 1],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -222,6 +230,7 @@ class FlashRWAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )
         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -276,6 +285,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             weights=weights,
             bias=config.bias,
         )
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")
         self.dense = load_row(
             config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
         )
@@ -311,7 +321,10 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)

         kv_cache.store(
-            key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots
+            key=kv[:, :, 0].contiguous(),
+            value=kv[:, :, 1].contiguous(),
+            slots=slots,
+            kv_scales=self.kv_scales,
         )

         # Prefill
@@ -322,6 +335,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 key=kv[:, :, 0],
                 value=kv[:, :, 1],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -336,6 +350,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )
         return self.dense(
...
@@ -17,6 +17,7 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     get_linear,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
@@ -257,6 +258,7 @@ class FlashMQAttention(torch.nn.Module):
         self.c_proj = load_row(
             config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
         )
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")
         self.kv_head_mapping = torch.zeros(
             self.num_heads, dtype=torch.int32, device=weights.device
         )
@@ -282,7 +284,12 @@ class FlashMQAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         key_value = key_value.view(-1, 2, 1, self.head_size)

-        kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots)
+        kv_cache.store(
+            key=key_value[:, 0],
+            value=key_value[:, 1],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -292,6 +299,7 @@ class FlashMQAttention(torch.nn.Module):
                 key=key_value[:, 0],
                 value=key_value[:, 1],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -306,6 +314,7 @@ class FlashMQAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )
         return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
...
@@ -38,6 +38,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
     FastRMSNorm,
@@ -188,6 +189,7 @@ class Starcoder2Attention(torch.nn.Module):
         )

         self.query_key_value = load_attention(config, prefix, weights)
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")

         self.o_proj = TensorParallelRowLinear.load(
             config,
@@ -231,7 +233,12 @@ class Starcoder2Attention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
+        kv_cache.store(
+            key=kv_to_cache[:, 0],
+            value=kv_to_cache[:, 1],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -241,6 +248,7 @@ class Starcoder2Attention(torch.nn.Module):
                 key=kv_to_cache[:, 0],
                 value=kv_to_cache[:, 1],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -256,6 +264,7 @@ class Starcoder2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
...
@@ -2283,6 +2283,7 @@ class FlashCausalLM(Model):
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
+                kv_cache_dtype=self.kv_cache_dtype,
                 dtype=self.dtype,
                 window_left=self.sliding_window,
             )
...
@@ -207,7 +207,9 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()

-    def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
+    def get_tensor(
+        self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
+    ) -> torch.Tensor:
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
...