norm / vllm · Commits · 928de468

Implement PagedAttention V2 (#1348)

Unverified commit 928de468, authored Oct 16, 2023 by Woosuk Kwon and committed via GitHub on Oct 16, 2023. Parent: 29678cd2.

Showing 6 changed files with 764 additions and 139 deletions (+764 −139):
benchmarks/kernels/benchmark_paged_attention.py    +197  −0
csrc/attention.cpp                                 +24   −4
csrc/attention/attention_kernels.cu                +413  −71
csrc/attention/dtype_bfloat16.cuh                  +5    −0
tests/kernels/test_attention.py                    +54   −17
vllm/model_executor/layers/attention.py            +71   −47
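For context when reading the diffs below: PagedAttention V2 splits each sequence's cached context into fixed-size partitions (512 tokens in this commit), computes attention for each partition separately into tmp_output while recording that partition's max logit and exp-sum, and then reduces across partitions into the final output. The following is a hedged, pure-PyTorch sketch of that reduction step only, not the CUDA kernel added in this commit; it assumes each partition's slice of tmp_output is already normalized by that partition's own exp-sum, and uses the tensor shapes introduced below (exp_sums and max_logits: [num_seqs, num_heads, num_partitions]; tmp_output: [num_seqs, num_heads, num_partitions, head_size]).

    # Hedged reference sketch of the cross-partition reduction (not the CUDA kernel).
    import torch

    def reduce_partitions(tmp_output: torch.Tensor,
                          exp_sums: torch.Tensor,
                          max_logits: torch.Tensor) -> torch.Tensor:
        # Global max logit per (seq, head) for a numerically stable merge.
        global_max = max_logits.max(dim=-1, keepdim=True).values       # [s, h, 1]
        # Rescale each partition's exp-sum from its local max to the global max.
        rescaled_sums = exp_sums * torch.exp(max_logits - global_max)  # [s, h, p]
        # Each partition's weight in the softmax-weighted average.
        weights = rescaled_sums / rescaled_sums.sum(dim=-1, keepdim=True)
        # Weighted sum of per-partition outputs gives the final attention output.
        return (tmp_output * weights.unsqueeze(-1)).sum(dim=2)         # [s, h, d]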
benchmarks/kernels/benchmark_paged_attention.py    (new file, mode 0 → 100644)

import argparse
import random
import time

import torch

from vllm import attention_ops

NUM_BLOCKS = 1024
# Number of tokens processed per partition by PagedAttention V2.
PARTITION_SIZE = 512


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    context_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    # Map each query head to its KV head (for grouped/multi-query attention).
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")

    context_lens = [context_len for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV cache.
    # x: number of key elements packed along the last cache dim (16 bytes worth).
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
    key_cache.uniform_(-scale, scale)
    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
    value_cache = torch.empty(size=value_cache_shape,
                              dtype=dtype,
                              device="cuda")
    value_cache.uniform_(-scale, scale)

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    def run_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if version == "v1":
                attention_ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            elif version == "v2":
                attention_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            # Stop CUDA profiling.
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--context-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    dtype_to_torch_dtype = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
        context_len=args.context_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=dtype_to_torch_dtype[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
    )
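One plausible invocation of the new benchmark, spelling out the defaults defined by the argparse flags above:

    python benchmarks/kernels/benchmark_paged_attention.py \
        --version v2 --batch-size 8 --context-len 4096 \
        --num-query-heads 64 --num-kv-heads 8 --head-size 128 \
        --block-size 16 --dtype half --seed 0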
csrc/attention.cpp

 #include <torch/extension.h>
 #include <c10/util/Optional.h>

-void single_query_cached_kv_attention(
+void paged_attention_v1(
   torch::Tensor& out,
   torch::Tensor& query,
   torch::Tensor& key_cache,
 ...
@@ -14,9 +14,29 @@ void single_query_cached_kv_attention(
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes);

+void paged_attention_v2(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
-    "single_query_cached_kv_attention",
-    &single_query_cached_kv_attention,
-    "Compute the attention between an input query and the cached key/value tensors");
+    "paged_attention_v1",
+    &paged_attention_v1,
+    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
+  m.def(
+    "paged_attention_v2",
+    &paged_attention_v2,
+    "PagedAttention V2.");
 }
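Relative to paged_attention_v1, the v2 entry point takes three extra workspace tensors ahead of the shared arguments: exp_sums, max_logits, and tmp_out. A hedged allocation sketch on the Python side, with shapes taken from the benchmark and test in this commit (num_partitions is derived from max_context_len and the 512-token partition size; the variable names are illustrative only):

    # Illustrative only: workspace shapes the v2 binding expects in this commit.
    num_partitions = (max_context_len + 512 - 1) // 512
    tmp_out = torch.empty(num_seqs, num_heads, num_partitions, head_size,
                          dtype=query.dtype, device=query.device)
    exp_sums = torch.empty(num_seqs, num_heads, num_partitions,
                           dtype=torch.float32, device=query.device)
    max_logits = torch.empty_like(exp_sums)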
csrc/attention/attention_kernels.cu    (+413 −71; diff collapsed in this view, not shown)
csrc/attention/dtype_bfloat16.cuh

 ...
@@ -420,6 +420,11 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
 #endif
 }

+// From bfloat16 to float32.
+inline __device__ float to_float(__nv_bfloat16 u) {
+  return __bfloat162float(u);
+}
+
 // Zero-out a variable.
 inline __device__ void zero(__nv_bfloat16& dst) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 ...
tests/kernels/test_attention.py

 ...
@@ -14,13 +14,14 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # - 512 as a buffer
 MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 NUM_BLOCKS = 128  # Arbitrary values for testing
+PARTITION_SIZE = 512

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_GEN_SEQS = [7]  # Arbitrary values for testing
-NUM_PREFILL_SEQS = [1, 3, 7]  # Arbitrary values for testing
+NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
 NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
-BLOCK_SIZES = [8, 16, 32]
+BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
 SEEDS = [0]
 ...
@@ -96,6 +97,7 @@ def ref_single_query_cached_kv_attention(
     output[i].copy_(out, non_blocking=True)


+@pytest.mark.parametrize("version", ["v1", "v2"])
 @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 ...
@@ -103,9 +105,9 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def test_single_query_cached_kv_attention(
+def test_paged_attention(
     kv_cache_factory,
+    version: str,
     num_seqs: int,
     num_heads: Tuple[int, int],
     head_size: int,
 ...
@@ -162,19 +164,54 @@ def test_single_query_cached_kv_attention(
     # Call the paged attention kernel.
     output = torch.empty_like(query)
-    attention_ops.single_query_cached_kv_attention(
-        output,
-        query,
-        key_cache,
-        value_cache,
-        head_mapping,
-        scale,
-        block_tables,
-        context_lens,
-        block_size,
-        max_context_len,
-        alibi_slopes,
-    )
+    if version == "v1":
+        attention_ops.paged_attention_v1(
+            output,
+            query,
+            key_cache,
+            value_cache,
+            head_mapping,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+        )
+    elif version == "v2":
+        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
+                          PARTITION_SIZE)
+        assert PARTITION_SIZE % block_size == 0
+        num_seqs, num_heads, head_size = output.shape
+        tmp_output = torch.empty(
+            size=(num_seqs, num_heads, num_partitions, head_size),
+            dtype=output.dtype,
+            device=output.device,
+        )
+        exp_sums = torch.empty(
+            size=(num_seqs, num_heads, num_partitions),
+            dtype=torch.float32,
+            device=output.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
+        attention_ops.paged_attention_v2(
+            output,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            head_mapping,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+        )
+    else:
+        assert False, f"Unknown version: {version}"

     # Run the reference implementation.
     ref_output = torch.empty_like(query)
 ...
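With the new version parameter, both kernels are exercised by the same test. One plausible way to run only the V2 cases locally (pytest's -k filter matches the parametrized test IDs):

    pytest tests/kernels/test_attention.py::test_paged_attention -k v2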
vllm/model_executor/layers/attention.py

 ...
@@ -15,6 +15,8 @@ from vllm.model_executor.layers.rotary_embedding import (
     RotaryEmbedding)

 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
+_PARTITION_SIZE = 512


 class PagedAttention(nn.Module):
 ...
@@ -130,6 +132,14 @@ class PagedAttention(nn.Module):
             output.copy_(out.squeeze(0))
         return output

+    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
+        """Returns the slopes for the alibi attention bias.
+
+        Returns:
+            slopes: shape = [num_heads]
+        """
+        return None
+
     def single_query_cached_kv_attention(
         self,
         output: torch.Tensor,
 ...
@@ -137,6 +147,7 @@ class PagedAttention(nn.Module):
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
         input_metadata: InputMetadata,
+        alibi_slopes: Optional[torch.Tensor],
     ) -> None:
         """PagedAttention for the generation tokens.
 ...
@@ -148,21 +159,65 @@ class PagedAttention(nn.Module):
             value_cache: shape = [num_blocks, num_kv_heads, head_size,
                 block_size]
             input_metadata: metadata for paged attention.
+            alibi_slopes: shape = [num_heads]
         """
         block_size = value_cache.shape[3]
-        attention_ops.single_query_cached_kv_attention(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            self.head_mapping,
-            self.scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            None,  # alibi_slopes
-        )
+        num_seqs, num_heads, head_size = query.shape
+        max_num_partitions = (
+            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
+            _PARTITION_SIZE)
+        # NOTE(woosuk): We use a simple heuristic to decide whether to use
+        # PagedAttention V1 or V2. If the number of partitions is 1, we use
+        # V1 to avoid the overhead of reduction. Also, if the number of
+        # sequences or heads is large, we use V1 since there is enough work
+        # to parallelize.
+        # TODO(woosuk): Tune this heuristic.
+        use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+        if use_v1:
+            # Run PagedAttention V1.
+            attention_ops.paged_attention_v1(
+                output,
+                query,
+                key_cache,
+                value_cache,
+                self.head_mapping,
+                self.scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+            )
+        else:
+            # Run PagedAttention V2.
+            assert _PARTITION_SIZE % block_size == 0
+            tmp_output = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=output.dtype,
+                device=output.device,
+            )
+            exp_sums = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=output.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+            attention_ops.paged_attention_v2(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                self.head_mapping,
+                self.scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+            )

     def forward(
         self,
 ...
@@ -253,7 +308,7 @@ class PagedAttention(nn.Module):
             self.single_query_cached_kv_attention(
                 output[num_prompt_tokens:num_valid_tokens],
                 query[num_prompt_tokens:num_valid_tokens], key_cache,
-                value_cache, input_metadata)
+                value_cache, input_metadata, self.get_alibi_slopes())

         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
 ...
@@ -431,36 +486,5 @@ class PagedAttentionWithALiBi(PagedAttention):
             start += prompt_len
         return output

-    def single_query_cached_kv_attention(
-        self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> None:
-        """PagedAttention with ALiBi bias for the generation tokens.
-
-        Args:
-            output: shape = [num_generation_tokens, num_heads, head_size]
-            query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for paged attention.
-        """
-        block_size = value_cache.shape[3]
-        attention_ops.single_query_cached_kv_attention(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            self.head_mapping,
-            self.scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            self.alibi_slopes,
-        )
+    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
+        return self.alibi_slopes
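The use_v1 heuristic above is easy to sanity-check with concrete numbers; a hedged arithmetic example using the benchmark defaults (8 sequences, 64 query heads, context length 4096) as inputs:

    # Illustrative check of the V1/V2 dispatch heuristic, using the benchmark defaults.
    _PARTITION_SIZE = 512
    num_seqs, num_heads, max_context_len = 8, 64, 4096
    max_num_partitions = (max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE  # 8
    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512  # 8 * 64 = 512, not > 512
    print(use_v1)  # False -> this shape is dispatched to PagedAttention V2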