Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
fe6d09ae
Unverified
Commit
fe6d09ae
authored
Feb 06, 2024
by
Lily Liu
Committed by
GitHub
Feb 06, 2024
Browse files
[Minor] More fix of test_cache.py CI test failure (#2750)
parent
ed70c70e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
11 deletions
+16
-11
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+4
-5
vllm/utils.py
vllm/utils.py
+12
-6
No files found.
tests/kernels/test_cache.py
View file @
fe6d09ae
...
@@ -181,16 +181,15 @@ def test_swap_blocks(
...
@@ -181,16 +181,15 @@ def test_swap_blocks(
num_blocks
:
int
,
num_blocks
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seed
:
int
,
seed
:
int
,
device
:
int
,
device
:
str
,
)
->
None
:
)
->
None
:
random
.
seed
(
seed
)
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
src_device
=
f
"
{
direction
[
0
]
}
:
{
device
}
"
if
direction
[
0
]
==
"cuda"
else
direction
[
0
]
src_device
=
device
if
direction
[
0
]
==
"cuda"
else
'cpu'
dst_device
=
f
"
{
direction
[
1
]
}
:
{
device
}
"
if
direction
[
dst_device
=
device
if
direction
[
1
]
==
"cuda"
else
'cpu'
1
]
==
"cuda"
else
direction
[
1
]
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
# For the same device, mapping must not overlap
# For the same device, mapping must not overlap
...
...
vllm/utils.py
View file @
fe6d09ae
...
@@ -258,10 +258,13 @@ def create_kv_caches_with_random(
...
@@ -258,10 +258,13 @@ def create_kv_caches_with_random(
key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
dtype
=
torch_dtype
,
dtype
=
torch_dtype
,
device
=
device
)
device
=
device
)
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
if
cache_dtype
==
'fp8_e5m2'
:
key_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
'fp8_e5m2'
:
_generate_random_fp8_e5m2
(
key_cache
,
-
scale
,
scale
)
_generate_random_fp8_e5m2
(
key_cache
,
-
scale
,
scale
)
elif
torch_dtype
in
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]:
key_cache
.
uniform_
(
-
scale
,
scale
)
else
:
raise
ValueError
(
f
"Does not support key cache of type
{
cache_dtype
}
"
)
key_caches
.
append
(
key_cache
)
key_caches
.
append
(
key_cache
)
value_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
,
block_size
)
value_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
,
block_size
)
...
@@ -270,9 +273,12 @@ def create_kv_caches_with_random(
...
@@ -270,9 +273,12 @@ def create_kv_caches_with_random(
value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
dtype
=
torch_dtype
,
dtype
=
torch_dtype
,
device
=
device
)
device
=
device
)
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
if
cache_dtype
==
'fp8_e5m2'
:
value_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
'fp8_e5m2'
:
_generate_random_fp8_e5m2
(
value_cache
,
-
scale
,
scale
)
_generate_random_fp8_e5m2
(
value_cache
,
-
scale
,
scale
)
elif
torch_dtype
in
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]:
value_cache
.
uniform_
(
-
scale
,
scale
)
else
:
raise
ValueError
(
f
"Does not support value cache of type
{
cache_dtype
}
"
)
value_caches
.
append
(
value_cache
)
value_caches
.
append
(
value_cache
)
return
key_caches
,
value_caches
return
key_caches
,
value_caches
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment