Unverified Commit 2358c2bb authored by Daniël de Kok, committed by GitHub

Add basic FP8 KV cache support (#2603)

* Add basic FP8 KV cache support

This change adds rudimentary FP8 KV cache support. It is enabled by passing
`--kv-cache-dtype fp8_e5m2` to the launcher, which then uses this type for the
KV cache. However, support is still limited:

* Only the `fp8_e5m2` type is supported.
* The KV cache layout is the same as `float16`/`bfloat16` (HND).
* The FP8 KV cache is only supported for FlashInfer.
* Loading of scales is not yet supported.

* Fix Cargo.toml
parent 68103079
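
To see why an FP8 cache is attractive, here is a rough sketch of the per-token KV cache footprint for `float16` versus `fp8_e5m2`. The layer, head, and head-size numbers are illustrative assumptions (roughly Llama-3-8B-shaped), not values taken from this change.

```python
import torch

# Rough per-token KV cache footprint for float16 vs. fp8_e5m2.
# The shapes below are assumptions for illustration only.
num_layers, num_kv_heads, head_size = 32, 8, 128


def kv_bytes_per_token(dtype: torch.dtype) -> int:
    element_size = torch.tensor([], dtype=dtype).element_size()
    # Two tensors (key and value) per layer.
    return 2 * num_layers * num_kv_heads * head_size * element_size


print(kv_bytes_per_token(torch.float16))      # 131072 bytes (~128 KiB)
print(kv_bytes_per_token(torch.float8_e5m2))  # 65536 bytes (~64 KiB)
```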
@@ -4177,7 +4177,7 @@ dependencies = [
 [[package]]
 name = "text-generation-backends-trtllm"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4200,7 +4200,7 @@ dependencies = [
 [[package]]
 name = "text-generation-benchmark"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "average",
  "clap 4.5.18",
@@ -4220,7 +4220,7 @@ dependencies = [
 [[package]]
 name = "text-generation-client"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -4238,7 +4238,7 @@ dependencies = [
 [[package]]
 name = "text-generation-launcher"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "clap 4.5.18",
  "ctrlc",
@@ -4259,7 +4259,7 @@ dependencies = [
 [[package]]
 name = "text-generation-router"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4308,7 +4308,7 @@ dependencies = [
 [[package]]
 name = "text-generation-router-v2"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4357,7 +4357,7 @@ dependencies = [
 [[package]]
 name = "text-generation-router-v3"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
...
@@ -89,6 +89,15 @@ Options:
           [env: DTYPE=]
           [possible values: float16, bfloat16]
 
 ```
+## KV_CACHE_DTYPE
+```shell
+      --kv-cache-dtype <KV_CACHE_DTYPE>
+          Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
+
+          [env: KV_CACHE_DTYPE=]
+          [possible values: fp8_e5m2]
+
+```
 ## TRUST_REMOTE_CODE
 ```shell
...
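
Once a server is running with `--kv-cache-dtype fp8_e5m2`, requests are made exactly as before; the cache dtype is transparent to clients. A minimal usage sketch, assuming a local instance listening on port 3000 (adjust the URL for your deployment):

```python
from huggingface_hub import InferenceClient

# Assumes a TGI server started with `--kv-cache-dtype fp8_e5m2` is reachable here.
client = InferenceClient("http://localhost:3000")
print(client.text_generation("What is deep learning?", max_new_tokens=10))
```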
@@ -336,6 +336,7 @@ def launcher(event_loop):
         use_flash_attention: bool = True,
         disable_grammar_support: bool = False,
         dtype: Optional[str] = None,
+        kv_cache_dtype: Optional[str] = None,
         revision: Optional[str] = None,
         max_input_length: Optional[int] = None,
         max_batch_prefill_tokens: Optional[int] = None,
@@ -375,6 +376,9 @@ def launcher(event_loop):
         if dtype is not None:
             args.append("--dtype")
             args.append(dtype)
+        if kv_cache_dtype is not None:
+            args.append("--kv-cache-dtype")
+            args.append(kv_cache_dtype)
         if revision is not None:
             args.append("--revision")
             args.append(revision)
@@ -434,6 +438,7 @@ def launcher(event_loop):
         use_flash_attention: bool = True,
         disable_grammar_support: bool = False,
         dtype: Optional[str] = None,
+        kv_cache_dtype: Optional[str] = None,
         revision: Optional[str] = None,
         max_input_length: Optional[int] = None,
         max_batch_prefill_tokens: Optional[int] = None,
@@ -456,6 +461,9 @@ def launcher(event_loop):
         if dtype is not None:
             args.append("--dtype")
             args.append(dtype)
+        if kv_cache_dtype is not None:
+            args.append("--kv-cache-dtype")
+            args.append(kv_cache_dtype)
         if revision is not None:
             args.append("--revision")
             args.append(revision)
@@ -589,7 +597,6 @@ def generate_multi():
         max_new_tokens: int,
         seed: Optional[int] = None,
     ) -> List[Response]:
         import numpy as np

         arange = np.arange(len(prompts))
...
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.079956055,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.2763672,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37548828,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4628906,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02885437,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.2565918,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0063438416,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3056641,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.6035156,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 3,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 374,
"logprob": -22.96875,
"text": " is"
},
{
"id": 5655,
"logprob": -10.71875,
"text": " deep"
},
{
"id": 6975,
"logprob": -2.6992188,
"text": " learning"
},
{
"id": 30,
"logprob": -4.8398438,
"text": "?"
}
],
"seed": 0,
"tokens": [
{
"id": 720,
"logprob": -0.4411621,
"special": false,
"text": " \n"
},
{
"id": 220,
"logprob": -0.35864258,
"special": false,
"text": " "
},
{
"id": 128001,
"logprob": 0.0,
"special": true,
"text": "<|end_of_text|>"
}
],
"top_tokens": null
},
"generated_text": "What is deep learning? \n "
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
}
]
import pytest


@pytest.fixture(scope="module")
def flash_llama_fp8_kv_cache_handle(launcher):
    with launcher(
        "meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache_handle):
    await flash_llama_fp8_kv_cache_handle.health(300)
    return flash_llama_fp8_kv_cache_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snapshot):
    response = await flash_llama_fp8_kv_cache.generate(
        "What is deep learning?", max_new_tokens=10, decoder_input_details=True
    )

    assert (
        response.generated_text
        == " Deep learning is a subset of machine learning that is"
    )
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_all_params(
    flash_llama_fp8_kv_cache, response_snapshot
):
    response = await flash_llama_fp8_kv_cache.generate(
        "What is deep learning?",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_load(
    flash_llama_fp8_kv_cache, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_llama_fp8_kv_cache, "What is deep learning?", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert (
        responses[0].generated_text
        == " Deep learning is a subset of machine learning that is"
    )
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"Different messages : {[r.generated_text for r in responses]}"
    assert responses == response_snapshot
@@ -301,6 +301,22 @@ impl std::fmt::Display for Dtype {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum KVCacheDtype {
+    #[clap(name = "fp8_e5m2")]
+    Fp8e5m2,
+}
+
+impl std::fmt::Display for KVCacheDtype {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            KVCacheDtype::Fp8e5m2 => {
+                write!(f, "fp8_e5m2")
+            }
+        }
+    }
+}
+
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum RopeScaling {
     Linear,
@@ -402,6 +418,12 @@ struct Args {
     #[clap(long, env, value_enum)]
     dtype: Option<Dtype>,
 
+    /// Specify the dtype for the key-value cache. When this option is not provided,
+    /// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
+    /// the only supported value is `fp8_e5m2` on CUDA.
+    #[clap(long, env, value_enum)]
+    kv_cache_dtype: Option<KVCacheDtype>,
+
     /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
     /// encouraged when loading a model with custom code to ensure no malicious code has been
     /// contributed in a newer revision.
@@ -670,6 +692,7 @@ fn shard_manager(
     quantize: Option<Quantization>,
     speculate: Option<usize>,
     dtype: Option<Dtype>,
+    kv_cache_dtype: Option<KVCacheDtype>,
     trust_remote_code: bool,
     uds_path: String,
     rank: usize,
@@ -743,6 +766,11 @@ fn shard_manager(
         shard_args.push(dtype.to_string())
     }
 
+    if let Some(kv_cache_dtype) = kv_cache_dtype {
+        shard_args.push("--kv-cache-dtype".to_string());
+        shard_args.push(kv_cache_dtype.to_string())
+    }
+
     // Model optional revision
     if let Some(revision) = revision {
         shard_args.push("--revision".to_string());
@@ -1299,6 +1327,7 @@ fn spawn_shards(
     let otlp_service_name = args.otlp_service_name.clone();
     let speculate = args.speculate;
     let dtype = args.dtype;
+    let kv_cache_dtype = args.kv_cache_dtype;
     let trust_remote_code = args.trust_remote_code;
     let master_port = args.master_port;
     let disable_custom_kernels = args.disable_custom_kernels;
@@ -1317,6 +1346,7 @@ fn spawn_shards(
             quantize,
             speculate,
             dtype,
+            kv_cache_dtype,
             trust_remote_code,
             uds_path,
             rank,
...
@@ -30,6 +30,10 @@ class Dtype(str, Enum):
     bloat16 = "bfloat16"
 
 
+class KVCacheDtype(str, Enum):
+    fp8_e5m2 = "fp8_e5m2"
+
+
 @app.command()
 def serve(
     model_id: str,
@@ -38,6 +42,7 @@ def serve(
     quantize: Optional[Quantization] = None,
     speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
+    kv_cache_dtype: Optional[KVCacheDtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -97,6 +102,7 @@ def serve(
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = None if dtype is None else dtype.value
+    kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
     if dtype is not None and quantize not in {
         None,
         "bitsandbytes",
@@ -114,6 +120,7 @@ def serve(
         quantize,
         speculate,
         dtype,
+        kv_cache_dtype,
         trust_remote_code,
         uds_path,
         max_input_tokens,
...
+from text_generation_server.utils.import_utils import SYSTEM
 import os
-from text_generation_server.utils.import_utils import SYSTEM
 
 from .common import Seqlen
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 
 if SYSTEM == "cuda":
     from .cuda import (
+        PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
         attention,
         paged_attention,
         reshape_and_cache,
-        SUPPORTS_WINDOWING,
-        PREFILL_IN_KV_CACHE,
     )
 elif SYSTEM == "rocm":
     from .rocm import (
+        PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
         attention,
         paged_attention,
         reshape_and_cache,
-        PREFILL_IN_KV_CACHE,
-        SUPPORTS_WINDOWING,
     )
 elif SYSTEM == "ipex":
     from .ipex import (
+        PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
         attention,
         paged_attention,
         reshape_and_cache,
-        PREFILL_IN_KV_CACHE,
-        SUPPORTS_WINDOWING,
     )
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
 
+# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
+from .kv_cache import KVCache
+
 __all__ = [
     "attention",
@@ -39,5 +42,6 @@ __all__ = [
     "reshape_and_cache",
     "PREFILL_IN_KV_CACHE",
     "SUPPORTS_WINDOWING",
+    "KVCache",
     "Seqlen",
 ]
@@ -355,3 +355,11 @@ else:
     # have a configuration that requires flash-attention v1, which
     # does not support block tables.
     PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
+
+__all__ = [
+    "PREFILL_IN_KV_CACHE",
+    "SUPPORTS_WINDOWING",
+    "attention",
+    "paged_attention",
+    "reshape_and_cache",
+]
@@ -80,3 +80,12 @@ def paged_attention(
         None,
     )
     return out
+
+
+__all__ = [
+    "PREFILL_IN_KV_CACHE",
+    "SUPPORTS_WINDOWING",
+    "attention",
+    "paged_attention",
+    "reshape_and_cache",
+]
from typing import Tuple

import torch

from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import reshape_and_cache


class KVCache:
    """
    Key-value cache for attention layers.
    """

    kv_cache: Tuple[torch.Tensor, torch.Tensor]

    def __init__(
        self,
        *,
        num_blocks: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Construct the key-value cache for a layer."""

        if (
            dtype == torch.float8_e5m2
            and ATTENTION != "flashinfer"
            and SYSTEM != "cuda"
        ):
            raise ValueError(
                "float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
            )

        element_size = torch.tensor([], dtype=dtype).element_size()
        if SYSTEM == "ipex" and device.type == "xpu":
            x = 1
        else:
            x = BLOCK_SIZE // element_size

        if ATTENTION in {"flashdecoding", "flashinfer"}:
            self.kv_cache = (
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),
                    dtype=dtype,
                    device=device,
                ),
            )
        elif SYSTEM == "ipex" and device == torch.device("cpu"):
            self.kv_cache = (
                torch.empty(
                    (num_blocks, num_heads, BLOCK_SIZE, head_size),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, BLOCK_SIZE, head_size),
                    dtype=dtype,
                    device=device,
                ),
            )
        else:
            self.kv_cache = (
                torch.zeros(
                    (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.zeros(
                    (num_blocks, num_heads, head_size, BLOCK_SIZE),
                    dtype=dtype,
                    device=device,
                ),
            )

    @property
    def key(self):
        """Get the key cache."""
        return self.kv_cache[0]

    @property
    def value(self):
        """Get the value cache."""
        return self.kv_cache[1]

    def store(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        slots: torch.Tensor,
    ):
        """Store the key and value at the given slots."""
        key_cache = self.kv_cache[0]
        value_cache = self.kv_cache[1]

        if ATTENTION in {"flashdecoding", "flashinfer"}:
            # TODO: add scale
            key = key.to(key_cache.dtype)
            value = value.to(value_cache.dtype)
            if key_cache.dtype == torch.float8_e5m2:
                # Torch index_put does not support float8_e5m2 yet, so
                # put as raw data instead.
                key_cache = key_cache.view(torch.uint8)
                value_cache = value_cache.view(torch.uint8)
                key = key.view(torch.uint8)
                value = value.view(torch.uint8)
            shape = key_cache.shape
            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
        else:
            reshape_and_cache(key, value, key_cache, value_cache, slots)
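
The `store()` path above works around the fact that `index_put_` does not yet support `float8_e5m2` by viewing the cache as raw bytes. A standalone sketch of that trick, assuming a PyTorch build that provides `torch.float8_e5m2`; all shapes are illustrative:

```python
import torch

# Build a small fp8 key cache (created in float16 first, then cast, to keep
# the sketch on well-supported ops).
num_slots, num_heads, head_size = 16, 2, 4
key_cache = torch.zeros((num_slots, num_heads, head_size), dtype=torch.float16).to(
    torch.float8_e5m2
)

new_keys = torch.randn((3, num_heads, head_size)).to(torch.float8_e5m2)
slots = torch.tensor([0, 5, 9])

# Reinterpret both tensors as uint8 so the indexed assignment is supported;
# the uint8 view shares storage, so this writes into key_cache.
key_cache.view(torch.uint8)[slots] = new_keys.view(torch.uint8)

# Reading back through the byte view recovers the (quantized) keys.
restored = key_cache.view(torch.uint8)[slots].view(torch.float8_e5m2).to(torch.float16)
print(restored)
```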
@@ -306,3 +306,11 @@ elif ENGINE == "triton":
 else:
     raise RuntimeError(f"Unknown attention engine {ENGINE}")
+
+__all__ = [
+    "PREFILL_IN_KV_CACHE",
+    "SUPPORTS_WINDOWING",
+    "attention",
+    "paged_attention",
+    "reshape_and_cache",
+]
@@ -342,6 +342,7 @@ def get_model(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[str],
+    kv_cache_dtype: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
 ) -> Model:
@@ -403,6 +404,13 @@ def get_model(
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")
 
+    if kv_cache_dtype is None:
+        kv_cache_dtype = dtype
+    elif kv_cache_dtype == "fp8_e5m2":
+        kv_cache_dtype = torch.float8_e5m2
+    else:
+        raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")
+
     if speculate is not None:
         set_speculate(speculate)
     else:
@@ -563,6 +571,7 @@ def get_model(
             speculator=speculator,
             default_dtype=torch.bfloat16,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
             config_class=DeepseekV2Config,
@@ -617,6 +626,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
             aliases={"transformer.wte.weight": ["lm_head.weight"]},
@@ -668,6 +678,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
@@ -703,6 +714,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
@@ -741,6 +753,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
             config_class=GPTNeoXConfig,
@@ -774,6 +787,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
@@ -797,6 +811,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
@@ -836,6 +851,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
@@ -859,6 +875,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             # Works better for these models
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
@@ -884,6 +901,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             # Works better for these models
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
@@ -910,6 +928,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
@@ -934,6 +953,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             # Dbrx works better in bfloat16.
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
@@ -964,6 +984,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             aliases={
                 "lm_head.weight": ["transformer.word_embeddings.weight"],
                 "transformer.word_embeddings.weight": ["lm_head.weight"],
@@ -982,6 +1003,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             aliases={
                 "lm_head.weight": ["transformer.word_embeddings.weight"],
                 "transformer.word_embeddings.weight": ["lm_head.weight"],
@@ -1009,6 +1031,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
@@ -1033,6 +1056,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
@@ -1057,6 +1081,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
@@ -1083,6 +1108,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
@@ -1162,6 +1188,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
             # XXX: Extremely important to cap resolution in order to limit
@@ -1179,6 +1206,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             # Works better for these models
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
@@ -1197,6 +1225,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
         )
     else:
@@ -1269,6 +1298,7 @@ def get_model_with_lora_adapters(
     quantize: Optional[str],
     speculate: Optional[int],
    dtype: Optional[str],
+    kv_cache_dtype: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
     adapter_to_index: Dict[str, int],
@@ -1282,6 +1312,7 @@ def get_model_with_lora_adapters(
         quantize,
         speculate,
         dtype,
+        kv_cache_dtype,
         trust_remote_code,
         max_input_tokens,
     )
...
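
For readers skimming the diff, the dtype resolution that `get_model` now performs can be restated as the following hedged, standalone sketch (only `fp8_e5m2` is accepted besides falling back to the model dtype; `resolve_kv_cache_dtype` is an illustrative helper name, not part of the PR):

```python
import torch


def resolve_kv_cache_dtype(kv_cache_dtype, model_dtype: torch.dtype) -> torch.dtype:
    # No explicit KV cache dtype: reuse the model dtype.
    if kv_cache_dtype is None:
        return model_dtype
    # Currently the only supported explicit value.
    if kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")
```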
@@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.utils.import_utils import SYSTEM
@@ -291,15 +290,15 @@ class FlashCohereAttention(torch.nn.Module):
 
         self.rotary_emb(query, key, cos, sin)
 
-        reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=key, value=value, slots=slots)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
-                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
+                kv_cache.key if PREFILL_IN_KV_CACHE else key,
+                kv_cache.value if PREFILL_IN_KV_CACHE else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -308,8 +307,8 @@ class FlashCohereAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
@@ -28,7 +28,6 @@ if SYSTEM != "ipex":
     from text_generation_server.layers.attention import (
         paged_attention,
         attention,
-        reshape_and_cache,
         Seqlen,
         PREFILL_IN_KV_CACHE,
     )
@@ -330,15 +329,15 @@ class DbrxAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -347,8 +346,8 @@ class DbrxAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
@@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
     Seqlen,
     attention,
     paged_attention,
-    reshape_and_cache,
 )
 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastRMSNorm
@@ -321,15 +320,15 @@ class DeepseekV2Attention(torch.nn.Module):
             value, (0, self.head_pad_size - self.value_head_size), value=0
         )
 
-        reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=key, value=value, slots=slots)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
-                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
+                kv_cache.key if PREFILL_IN_KV_CACHE else key,
+                kv_cache.value if PREFILL_IN_KV_CACHE else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -338,8 +337,8 @@ class DeepseekV2Attention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
@@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -253,15 +252,15 @@ class FlashGemma2Attention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -273,8 +272,8 @@ class FlashGemma2Attention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
@@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
     PREFILL_IN_KV_CACHE,
 )
@@ -224,15 +223,15 @@ class FlashGemmaAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -242,8 +241,8 @@ class FlashGemmaAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...
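
The modeling changes above all follow the same pattern: `reshape_and_cache(...)` on the raw `(key_cache, value_cache)` tuple becomes `kv_cache.store(...)`, and `kv_cache[0]`/`kv_cache[1]` become `kv_cache.key`/`kv_cache.value`. A condensed sketch of that pattern, with `cache_and_select` as a hypothetical helper (not part of the PR); it assumes a working TGI server environment so the attention package can be imported:

```python
from text_generation_server.layers.attention import KVCache, PREFILL_IN_KV_CACHE


def cache_and_select(kv_cache: KVCache, key, value, slots, prefilling: bool):
    """Store this step's key/value, then pick the tensors to attend over."""
    # New API: one call replaces reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots).
    kv_cache.store(key=key, value=value, slots=slots)

    if prefilling and not PREFILL_IN_KV_CACHE:
        # Some backends attend over the freshly computed key/value during prefill.
        return key, value
    # Otherwise attend over the (possibly FP8) cache contents.
    return kv_cache.key, kv_cache.value
```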