Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d2b58ca2
Unverified
Commit
d2b58ca2
authored
Apr 03, 2025
by
Liangfu Chen
Committed by
GitHub
Apr 03, 2025
Browse files
[Neuron][kernel] Fuse kv cache into a single tensor (#15911)
Signed-off-by:
Liangfu Chen
<
liangfc@amazon.com
>
parent
82e7e19a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
56 deletions
+46
-56
tests/neuron/1_core/test_cache.py
tests/neuron/1_core/test_cache.py
+3
-1
tests/neuron/1_core/test_prefix_prefill.py
tests/neuron/1_core/test_prefix_prefill.py
+5
-8
vllm/attention/ops/nki_flash_attn.py
vllm/attention/ops/nki_flash_attn.py
+38
-47
No files found.
tests/neuron/1_core/test_cache.py
View file @
d2b58ca2
...
@@ -64,9 +64,11 @@ def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
...
@@ -64,9 +64,11 @@ def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
key_cache
=
torch
.
zeros_like
(
key_cache_cpu
,
device
=
device
)
key_cache
=
torch
.
zeros_like
(
key_cache_cpu
,
device
=
device
)
value_cache
=
torch
.
zeros_like
(
value_cache_cpu
,
device
=
device
)
value_cache
=
torch
.
zeros_like
(
value_cache_cpu
,
device
=
device
)
slot_mapping
=
slot_mapping_cpu
.
to
(
device
)
slot_mapping
=
slot_mapping_cpu
.
to
(
device
)
kv_cache
=
torch
.
stack
([
key_cache
,
value_cache
])
# Run vectorized implementation on XLA device
# Run vectorized implementation on XLA device
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
)
reshape_and_cache
(
key
,
value
,
kv_cache
,
slot_mapping
)
key_cache
,
value_cache
=
torch
.
unbind
(
kv_cache
,
dim
=
0
)
# Move results back to CPU for comparison
# Move results back to CPU for comparison
key_cache_result
=
key_cache
.
cpu
()
key_cache_result
=
key_cache
.
cpu
()
...
...
tests/neuron/1_core/test_prefix_prefill.py
View file @
d2b58ca2
...
@@ -258,13 +258,13 @@ def sample_inputs(
...
@@ -258,13 +258,13 @@ def sample_inputs(
value
[
start_loc
:
end_loc
])
value
[
start_loc
:
end_loc
])
cur_ctx
+=
block_size
cur_ctx
+=
block_size
block_id
+=
1
block_id
+=
1
kv_cache
=
torch
.
stack
([
k_cache
,
v_cache
])
return
(
return
(
query
,
query
,
k
,
k
,
v
,
v
,
k_cache
,
kv_cache
,
v_cache
,
block_table
,
block_table
,
key
,
key
,
value
,
value
,
...
@@ -361,8 +361,7 @@ def test_contexted_kv_attention(
...
@@ -361,8 +361,7 @@ def test_contexted_kv_attention(
query
,
query
,
k_active
,
k_active
,
v_active
,
v_active
,
k_cache
,
kv_cache
,
v_cache
,
block_table
,
block_table
,
key
,
key
,
value
,
value
,
...
@@ -439,8 +438,7 @@ def test_contexted_kv_attention(
...
@@ -439,8 +438,7 @@ def test_contexted_kv_attention(
query
=
query
.
unsqueeze
(
0
).
permute
(
0
,
2
,
3
,
1
).
contiguous
()
query
=
query
.
unsqueeze
(
0
).
permute
(
0
,
2
,
3
,
1
).
contiguous
()
k
=
k
.
unsqueeze
(
0
).
permute
(
0
,
2
,
3
,
1
).
contiguous
()
k
=
k
.
unsqueeze
(
0
).
permute
(
0
,
2
,
3
,
1
).
contiguous
()
v
=
v
.
unsqueeze
(
0
).
permute
(
0
,
2
,
1
,
3
).
contiguous
()
v
=
v
.
unsqueeze
(
0
).
permute
(
0
,
2
,
1
,
3
).
contiguous
()
k_cache
=
k_cache
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
kv_cache
=
kv_cache
.
permute
(
0
,
1
,
3
,
2
,
4
).
contiguous
()
v_cache
=
v_cache
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
# transform block table
# transform block table
active_block_table
=
get_active_block_tables
(
active_block_table
=
get_active_block_tables
(
...
@@ -487,8 +485,7 @@ def test_contexted_kv_attention(
...
@@ -487,8 +485,7 @@ def test_contexted_kv_attention(
query
.
to
(
device
=
device
),
query
.
to
(
device
=
device
),
k
.
to
(
device
=
device
),
k
.
to
(
device
=
device
),
v
.
to
(
device
=
device
),
v
.
to
(
device
=
device
),
k_cache
.
to
(
device
=
device
),
kv_cache
.
to
(
device
=
device
),
v_cache
.
to
(
device
=
device
),
active_block_table
.
to
(
device
=
device
),
active_block_table
.
to
(
device
=
device
),
attn_mask
.
to
(
device
=
device
),
attn_mask
.
to
(
device
=
device
),
)
)
...
...
vllm/attention/ops/nki_flash_attn.py
View file @
d2b58ca2
...
@@ -144,8 +144,7 @@ def transform_block_tables_for_indirect_load(
...
@@ -144,8 +144,7 @@ def transform_block_tables_for_indirect_load(
def
load_kv_tile_from_cache
(
def
load_kv_tile_from_cache
(
cur_k_tile
,
cur_k_tile
,
cur_v_tile
,
cur_v_tile
,
key_cache
,
kv_cache
,
value_cache
,
block_tables
,
block_tables
,
large_k_tile_idx
,
large_k_tile_idx
,
num_blocks_per_large_tile
,
num_blocks_per_large_tile
,
...
@@ -169,8 +168,8 @@ def load_kv_tile_from_cache(
...
@@ -169,8 +168,8 @@ def load_kv_tile_from_cache(
for
load_idx
in
nl
.
affine_range
(
num_loads
):
for
load_idx
in
nl
.
affine_range
(
num_loads
):
i_p
=
nl
.
arange
(
B_P_SIZE
)[:,
None
]
i_p
=
nl
.
arange
(
B_P_SIZE
)[:,
None
]
i_f
=
nl
.
arange
(
tiled_block_size
*
B_D_SIZE
)[
None
,
:]
i_f
=
nl
.
arange
(
tiled_block_size
*
B_D_SIZE
)[
None
,
:]
loaded
=
nl
.
load
(
k
ey
_cache
[
block_tables
[
load_idx
,
i_p
,
loaded
=
nl
.
load
(
k
v
_cache
[
0
,
block_tables
[
load_idx
,
i_p
,
large_k_tile_idx
],
i_f
])
large_k_tile_idx
],
i_f
])
if
cur_k_tile
.
dtype
!=
loaded
.
dtype
:
if
cur_k_tile
.
dtype
!=
loaded
.
dtype
:
loaded
=
nl
.
copy
(
loaded
,
dtype
=
cur_k_tile
.
dtype
)
loaded
=
nl
.
copy
(
loaded
,
dtype
=
cur_k_tile
.
dtype
)
# Transpose SBUF tensor using PE
# Transpose SBUF tensor using PE
...
@@ -185,7 +184,7 @@ def load_kv_tile_from_cache(
...
@@ -185,7 +184,7 @@ def load_kv_tile_from_cache(
# load value cache
# load value cache
for
load_idx
in
nl
.
affine_range
(
num_loads
):
for
load_idx
in
nl
.
affine_range
(
num_loads
):
loaded
=
nl
.
load
(
v
alue
_cache
[
block_tables
[
load_idx
,
i_p
,
loaded
=
nl
.
load
(
k
v_cache
[
1
,
block_tables
[
load_idx
,
i_p
,
large_k_tile_idx
],
i_f
])
large_k_tile_idx
],
i_f
])
if
cur_v_tile
.
dtype
!=
loaded
.
dtype
:
if
cur_v_tile
.
dtype
!=
loaded
.
dtype
:
loaded
=
nl
.
copy
(
loaded
,
dtype
=
cur_v_tile
.
dtype
)
loaded
=
nl
.
copy
(
loaded
,
dtype
=
cur_v_tile
.
dtype
)
...
@@ -418,8 +417,7 @@ def flash_paged_attention(
...
@@ -418,8 +417,7 @@ def flash_paged_attention(
query
,
query
,
key
,
key
,
value
,
value
,
key_cache
,
kv_cache
,
value_cache
,
block_tables
,
block_tables
,
mask
,
mask
,
softmax_scale
=
None
,
softmax_scale
=
None
,
...
@@ -434,8 +432,7 @@ def flash_paged_attention(
...
@@ -434,8 +432,7 @@ def flash_paged_attention(
- query: shape (1, n_heads, d, seq_q)
- query: shape (1, n_heads, d, seq_q)
- key: shape (1, n_kv_heads, d, seq_k)
- key: shape (1, n_kv_heads, d, seq_k)
- value: shape (1, n_kv_heads, seq_v, d)
- value: shape (1, n_kv_heads, seq_v, d)
- key_cache: (num_blocks, n_kv_heads, block_size, d)
- kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
- value_cache: (num_blocks, n_kv_heads, block_size, d)
- block_tables: (num_active_blocks, )
- block_tables: (num_active_blocks, )
- mask: (seq_q, num_active_blocks * block_size + seq_q)
- mask: (seq_q, num_active_blocks * block_size + seq_q)
- o: shape (1, n_heads, seq_q, d)
- o: shape (1, n_heads, seq_q, d)
...
@@ -444,7 +441,7 @@ def flash_paged_attention(
...
@@ -444,7 +441,7 @@ def flash_paged_attention(
- We use continuous batching by default, so the batch dimension is
- We use continuous batching by default, so the batch dimension is
always 1, and different requests are concatenated along sequence
always 1, and different requests are concatenated along sequence
dimension.
dimension.
- We use paged cache blocks (k
ey_cache, value
_cache) to store KV cache.
- We use paged cache blocks (k
v
_cache) to store KV cache.
IO tensor dtypes:
IO tensor dtypes:
- This kernel assumes all IO tensors have the same dtype except for
- This kernel assumes all IO tensors have the same dtype except for
...
@@ -475,15 +472,13 @@ def flash_paged_attention(
...
@@ -475,15 +472,13 @@ def flash_paged_attention(
b
,
h
,
d
,
seqlen_q
=
query
.
shape
b
,
h
,
d
,
seqlen_q
=
query
.
shape
B_D_SIZE
=
d
B_D_SIZE
=
d
n_tile_q
=
seqlen_q
//
B_P_SIZE
# since q will be loaded on tensor engine
n_tile_q
=
seqlen_q
//
B_P_SIZE
# since q will be loaded on tensor engine
num_blocks
,
k_h
,
block_size
,
_
=
k
ey
_cache
.
shape
_
,
num_blocks
,
k_h
,
block_size
,
_
=
k
v
_cache
.
shape
q_h_per_k_h
=
h
//
k_h
q_h_per_k_h
=
h
//
k_h
assert
b
==
1
,
f
"invalid batch size
{
b
=
}
"
assert
b
==
1
,
f
"invalid batch size
{
b
=
}
"
assert
d
<=
128
,
f
" we do not support head_dim > 128, got head dim
{
d
=
}
"
assert
d
<=
128
,
f
" we do not support head_dim > 128, got head dim
{
d
=
}
"
cache_shape
=
(
num_blocks
,
k_h
,
block_size
,
d
)
cache_shape
=
(
2
,
num_blocks
,
k_h
,
block_size
,
d
)
assert
(
tuple
(
key_cache
.
shape
)
==
cache_shape
assert
(
tuple
(
kv_cache
.
shape
)
==
cache_shape
),
f
"
{
key_cache
.
shape
=
}
mismatch, expect
{
cache_shape
}
"
),
f
"
{
kv_cache
.
shape
=
}
mismatch, expect
{
cache_shape
}
"
assert
(
tuple
(
value_cache
.
shape
)
==
cache_shape
),
f
"
{
value_cache
.
shape
=
}
mismatch, expect
{
cache_shape
}
"
assert
key
is
None
or
tuple
(
key
.
shape
)
==
(
assert
key
is
None
or
tuple
(
key
.
shape
)
==
(
1
,
1
,
k_h
,
k_h
,
...
@@ -580,13 +575,13 @@ def flash_paged_attention(
...
@@ -580,13 +575,13 @@ def flash_paged_attention(
head_id
=
head_id
,
head_id
=
head_id
,
)
)
# Flatten KV cache to be
2
D for loading into SBUF
# Flatten KV cache to be
3
D for loading into SBUF
new_cache_shape
=
(
new_cache_shape
=
(
2
,
num_blocks
*
k_h
*
block_size_tiling_factor
,
num_blocks
*
k_h
*
block_size_tiling_factor
,
tiled_block_size
*
d
,
tiled_block_size
*
d
,
)
)
key_cache
=
key_cache
.
reshape
(
new_cache_shape
)
kv_cache
=
kv_cache
.
reshape
(
new_cache_shape
)
value_cache
=
value_cache
.
reshape
(
new_cache_shape
)
# Global Flash Attention accumulators
# Global Flash Attention accumulators
o_buffer
=
nl
.
zeros
(
o_buffer
=
nl
.
zeros
(
...
@@ -621,8 +616,7 @@ def flash_paged_attention(
...
@@ -621,8 +616,7 @@ def flash_paged_attention(
load_kv_tile_from_cache
(
load_kv_tile_from_cache
(
cur_k_tile
=
cur_k_tile
,
cur_k_tile
=
cur_k_tile
,
cur_v_tile
=
cur_v_tile
,
cur_v_tile
=
cur_v_tile
,
key_cache
=
key_cache
,
kv_cache
=
kv_cache
,
value_cache
=
value_cache
,
block_tables
=
block_tables_sbuf
,
block_tables
=
block_tables_sbuf
,
large_k_tile_idx
=
large_k_tile_idx
,
large_k_tile_idx
=
large_k_tile_idx
,
num_blocks_per_large_tile
=
num_blocks_per_large_tile
,
num_blocks_per_large_tile
=
num_blocks_per_large_tile
,
...
@@ -821,8 +815,7 @@ def flash_attn_varlen_nkifunc(
...
@@ -821,8 +815,7 @@ def flash_attn_varlen_nkifunc(
query
,
query
,
key
,
key
,
value
,
value
,
key_cache
,
kv_cache
,
value_cache
,
block_table
,
block_table
,
attn_mask
,
attn_mask
,
n_kv_head
=
None
,
n_kv_head
=
None
,
...
@@ -838,8 +831,7 @@ def flash_attn_varlen_nkifunc(
...
@@ -838,8 +831,7 @@ def flash_attn_varlen_nkifunc(
- query: (1, n_heads, d, seq_q)
- query: (1, n_heads, d, seq_q)
- key: (1, n_kv_heads, d, seq_k)
- key: (1, n_kv_heads, d, seq_k)
- value: (1, n_kv_heads, seq_v, d)
- value: (1, n_kv_heads, seq_v, d)
- key_cache: (n_blocks, n_kv_heads, block_size, d)
- kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
- value_cache: (n_blocks, n_kv_heads, block_size, d)
- block_tables: (n_active_blocks, )
- block_tables: (n_active_blocks, )
- attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
- attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
...
@@ -849,17 +841,17 @@ def flash_attn_varlen_nkifunc(
...
@@ -849,17 +841,17 @@ def flash_attn_varlen_nkifunc(
for better DMA throughput
for better DMA throughput
"""
"""
if
n_kv_head
is
None
:
if
n_kv_head
is
None
:
n_kv_head
=
key_cache
.
shape
[
1
]
n_kv_head
=
kv_cache
.
shape
[
2
]
assert
key_cache
.
shape
[
1
]
==
n_kv_head
assert
kv_cache
.
shape
[
0
]
==
2
assert
kv_cache
.
shape
[
2
]
==
n_kv_head
if
head_size
is
None
:
if
head_size
is
None
:
head_size
=
k
ey
_cache
.
shape
[
-
1
]
head_size
=
k
v
_cache
.
shape
[
-
1
]
kwargs
=
dict
(
kwargs
=
dict
(
query
=
query
,
query
=
query
,
key
=
key
,
key
=
key
,
value
=
value
,
value
=
value
,
key_cache
=
key_cache
,
kv_cache
=
kv_cache
,
value_cache
=
value_cache
,
block_tables
=
block_table
,
block_tables
=
block_table
,
mask
=
attn_mask
,
mask
=
attn_mask
,
softmax_scale
=
1.0
/
(
head_size
**
0.5
),
softmax_scale
=
1.0
/
(
head_size
**
0.5
),
...
@@ -874,8 +866,7 @@ def flash_attn_varlen_nkifunc(
...
@@ -874,8 +866,7 @@ def flash_attn_varlen_nkifunc(
def
reshape_and_cache
(
def
reshape_and_cache
(
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -886,29 +877,29 @@ def reshape_and_cache(
...
@@ -886,29 +877,29 @@ def reshape_and_cache(
(num_tokens, n_kv_head, d_head)
(num_tokens, n_kv_head, d_head)
value (torch.Tensor): Value tensor with shape
value (torch.Tensor): Value tensor with shape
(num_tokens, n_kv_head, d_head)
(num_tokens, n_kv_head, d_head)
key_cache (torch.Tensor): Key cache tensor with shape
kv_cache (torch.Tensor): Key/value cache tensor with shape
(num_blocks, n_kv_head, block_size, d_head)
(2, num_blocks, n_kv_head, block_size, d_head)
value_cache (torch.Tensor): Value cache tensor with shape
(num_blocks, n_kv_head, block_size, d_head)
slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
with shape (num_tokens)
with shape (num_tokens)
Returns:
Returns:
None: Updates the k
ey_cache and value
_cache tensor
s
in-place
None: Updates the k
v
_cache tensor in-place
"""
"""
block_size
=
key_cache
.
size
(
2
)
block_size
=
kv_cache
.
size
(
3
)
n_kv_head
=
key
.
size
(
1
)
# Calculate indices with explicit floor division
# Calculate indices with explicit floor division
block_indices
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indices
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
slot_mapping
%
block_size
# Create the head indices tensor
head_indices
=
torch
.
arange
(
n_kv_head
,
device
=
key
.
device
)
# Update caches using index_put_
# Update caches using index_put_
key_cache
.
index_put_
(
kv_cache
.
index_put_
(
(
block_indices
.
unsqueeze
(
1
),
(
torch
.
tensor
([
0
],
device
=
key
.
device
),
block_indices
[:,
None
],
torch
.
arange
(
key_cache
.
size
(
1
),
head_indices
[
None
,
:],
block_offsets
[:,
None
]),
key
)
device
=
key
.
device
),
block_offsets
.
unsqueeze
(
1
)),
key
)
kv_cache
.
index_put_
(
value_cache
.
index_put_
(
(
torch
.
tensor
([
1
],
device
=
key
.
device
),
block_indices
[:,
None
],
(
block_indices
.
unsqueeze
(
1
),
head_indices
[
None
,
:],
block_offsets
[:,
None
]),
value
)
torch
.
arange
(
value_cache
.
size
(
1
),
device
=
value
.
device
),
block_offsets
.
unsqueeze
(
1
)),
value
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment