Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c91b64f7
Unverified
Commit
c91b64f7
authored
Mar 10, 2025
by
Liangfu Chen
Committed by
GitHub
Mar 10, 2025
Browse files
[neuron] add reshape_and_cache (#14391)
parent
d6123170
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
126 additions
and
0 deletions
+126
-0
tests/neuron/test_cache.py
tests/neuron/test_cache.py
+83
-0
vllm/attention/ops/nki_flash_attn.py
vllm/attention/ops/nki_flash_attn.py
+43
-0
No files found.
tests/neuron/test_cache.py
0 → 100644
View file @
c91b64f7
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
from
vllm.attention.ops.nki_flash_attn
import
reshape_and_cache
@
pytest
.
mark
.
parametrize
(
"num_tokens, n_kv_head, d_head, num_blocks, block_size"
,
[
# Small model configuration (e.g., GPT-2 small)
(
32
,
12
,
64
,
4
,
128
),
# Typical sequence processing
(
1
,
12
,
64
,
4
,
128
),
# Single token update
(
128
,
12
,
64
,
4
,
128
),
# Longer sequence
# Medium model configuration (e.g., GPT-2 medium)
(
64
,
16
,
96
,
8
,
256
),
# Standard batch
(
256
,
16
,
96
,
8
,
256
),
# Large batch
# Large model configuration (e.g., GPT-3 style)
(
48
,
32
,
128
,
16
,
512
),
# Typical processing window
(
512
,
32
,
128
,
16
,
512
),
# Full context window
# Edge cases and stress tests
(
1024
,
8
,
32
,
32
,
32
),
# Many tokens, small heads
(
16
,
64
,
256
,
4
,
64
),
# Few tokens, many heads
(
2048
,
24
,
128
,
64
,
128
),
# Large scale test
# Minimal configurations for debugging
(
4
,
2
,
16
,
2
,
16
),
# Tiny test case
(
1
,
1
,
8
,
1
,
8
),
# Minimal possible
])
def
test_reshape_and_cache
(
num_tokens
,
n_kv_head
,
d_head
,
num_blocks
,
block_size
):
# Set random seed for reproducibility
torch
.
manual_seed
(
42
)
# Create CPU tensors for reference implementation
key_cpu
=
torch
.
randn
(
num_tokens
,
n_kv_head
,
d_head
)
/
torch
.
sqrt
(
torch
.
tensor
(
d_head
))
value_cpu
=
torch
.
randn
(
num_tokens
,
n_kv_head
,
d_head
)
/
torch
.
sqrt
(
torch
.
tensor
(
d_head
))
key_cache_cpu
=
torch
.
zeros
(
num_blocks
,
n_kv_head
,
block_size
,
d_head
)
value_cache_cpu
=
torch
.
zeros
(
num_blocks
,
n_kv_head
,
block_size
,
d_head
)
slot_mapping_cpu
=
torch
.
randperm
(
num_blocks
*
block_size
)[:
num_tokens
]
# Run reference implementation on CPU
block_indices
=
torch
.
div
(
slot_mapping_cpu
,
block_size
,
rounding_mode
=
"floor"
)
block_offsets
=
slot_mapping_cpu
%
block_size
for
i
in
range
(
num_tokens
):
block_idx
=
block_indices
[
i
]
block_offset
=
block_offsets
[
i
]
key_cache_cpu
[
block_idx
,
:,
block_offset
,
:]
=
key_cpu
[
i
]
value_cache_cpu
[
block_idx
,
:,
block_offset
,
:]
=
value_cpu
[
i
]
# Create XLA device tensors
device
=
torch
.
device
(
'xla'
)
key
=
key_cpu
.
to
(
device
)
value
=
value_cpu
.
to
(
device
)
key_cache
=
torch
.
zeros_like
(
key_cache_cpu
,
device
=
device
)
value_cache
=
torch
.
zeros_like
(
value_cache_cpu
,
device
=
device
)
slot_mapping
=
slot_mapping_cpu
.
to
(
device
)
# Run vectorized implementation on XLA device
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
)
# Move results back to CPU for comparison
key_cache_result
=
key_cache
.
cpu
()
value_cache_result
=
value_cache
.
cpu
()
# Assert results match
torch
.
testing
.
assert_close
(
key_cache_result
,
key_cache_cpu
,
rtol
=
1e-5
,
atol
=
1e-5
)
torch
.
testing
.
assert_close
(
value_cache_result
,
value_cache_cpu
,
rtol
=
1e-5
,
atol
=
1e-5
)
vllm/attention/ops/nki_flash_attn.py
View file @
c91b64f7
...
...
@@ -869,3 +869,46 @@ def flash_attn_varlen_nkifunc(
o
=
flash_paged_attention
[
1
,
n_kv_head
](
**
kwargs
)
return
o
def
reshape_and_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
None
:
"""
Writes key-value pairs to the KV cache at specified positions.
Args:
key (torch.Tensor): Key tensor with shape
(num_tokens, n_kv_head, d_head)
value (torch.Tensor): Value tensor with shape
(num_tokens, n_kv_head, d_head)
key_cache (torch.Tensor): Key cache tensor with shape
(num_blocks, n_kv_head, block_size, d_head)
value_cache (torch.Tensor): Value cache tensor with shape
(num_blocks, n_kv_head, block_size, d_head)
slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
with shape (num_tokens)
Returns:
None: Updates the key_cache and value_cache tensors in-place
"""
block_size
=
key_cache
.
size
(
2
)
# Calculate indices with explicit floor division
block_indices
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_offsets
=
slot_mapping
%
block_size
# Update caches using index_put_
key_cache
.
index_put_
(
(
block_indices
.
unsqueeze
(
1
),
torch
.
arange
(
key_cache
.
size
(
1
),
device
=
key
.
device
),
block_offsets
.
unsqueeze
(
1
)),
key
)
value_cache
.
index_put_
(
(
block_indices
.
unsqueeze
(
1
),
torch
.
arange
(
value_cache
.
size
(
1
),
device
=
value
.
device
),
block_offsets
.
unsqueeze
(
1
)),
value
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment