norm / vllm

Commit 4f65af0e (unverified)
Authored Jan 30, 2024 by Vladimir; committed by GitHub on Jan 30, 2024
Add swap_blocks unit tests (#2616)
Parent: d79ced32

Showing 1 changed file with 68 additions and 0 deletions.

tests/kernels/test_cache.py (+68, -0)
@@ -3,8 +3,11 @@ import random

 import pytest
 import torch

+from typing import Tuple
+
 from vllm._C import cache_ops

+COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing

@@ -153,3 +156,68 @@ def test_reshape_and_cache(
     assert torch.allclose(key_cache, cloned_key_cache)
     assert torch.allclose(value_cache, cloned_value_cache)
+
+
+@pytest.mark.parametrize("direction", COPYING_DIRECTION)
+@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_swap_blocks(
+    kv_cache_factory,
+    direction: Tuple[str, str],
+    num_mappings: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: int,
+) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    src_device = (f"{direction[0]}:{device}"
+                  if direction[0] == "cuda" else direction[0])
+    dst_device = (f"{direction[1]}:{device}"
+                  if direction[1] == "cuda" else direction[1])
+
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    # For the same device, mapping must not overlap
+    if src_device == dst_device:
+        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
+        dst_blocks = random.sample(remaining_blocks, num_mappings)
+    else:
+        dst_blocks = random.sample(range(num_blocks), num_mappings)
+    block_mapping = dict(zip(src_blocks, dst_blocks))
+
+    # Create the KV caches on the first device.
+    src_key_caches, src_value_caches = kv_cache_factory(
+        num_blocks, block_size, 1, num_heads, head_size, dtype, seed,
+        src_device)
+
+    # Create the KV caches on the second device.
+    dist_key_caches, dist_value_caches = kv_cache_factory(
+        num_blocks, block_size, 1, num_heads, head_size, dtype, seed,
+        dst_device)
+
+    src_key_caches_clone = src_key_caches[0].clone()
+    src_value_caches_clone = src_value_caches[0].clone()
+
+    # Call the swap_blocks kernel.
+    cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
+    cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
+                          block_mapping)
+
+    for src, dst in block_mapping.items():
+        assert torch.allclose(src_key_caches_clone[src].cpu(),
+                              dist_key_caches[0][dst].cpu())
+        assert torch.allclose(src_value_caches_clone[src].cpu(),
+                              dist_value_caches[0][dst].cpu())
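For readers of the diff, the behaviour the new test checks can be summarized with a short pure-PyTorch sketch. This is not part of the commit: the helper name swap_blocks_reference is hypothetical, and it assumes cache blocks are indexed along the first dimension, which is how the test itself indexes src_key_caches_clone[src] and dist_key_caches[0][dst].

import torch


def swap_blocks_reference(src_cache: torch.Tensor, dst_cache: torch.Tensor,
                          block_mapping: dict) -> None:
    # Copy each mapped block from the source cache into the destination
    # cache, mirroring what the test asserts: after the swap, dst_cache[dst]
    # must equal the pre-swap value of src_cache[src] for every (src, dst)
    # pair in block_mapping. copy_() also handles cross-device copies.
    for src_block, dst_block in block_mapping.items():
        dst_cache[dst_block].copy_(src_cache[src_block])

The new case runs under the file's existing pytest parametrization; an invocation such as pytest tests/kernels/test_cache.py -k test_swap_blocks selects it across all configured copy directions, dtypes, seeds, and devices.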