norm / vllm · Commits · Commit 56f738ae
"src/libtorchaudio/sox/effects.cpp" did not exist on "b33c539cfcf3e33af5060b08279aa69400a6314b"
Unverified commit 56f738ae, authored Feb 05, 2024 by Hongxia Yang, committed by GitHub on Feb 05, 2024.
[ROCm] Fix some kernels failed unit tests (#2498)
Parent: 72d3a30c

Showing 5 changed files with 62 additions and 12 deletions (+62 −12):
tests/kernels/allclose_default.py   +18 −0
tests/kernels/test_activation.py    +13 −3
tests/kernels/test_attention.py     +17 −5
tests/kernels/test_cache.py          +5 −1
tests/kernels/test_pos_encoding.py   +9 −3
tests/kernels/allclose_default.py (new file, mode 0 → 100644)
import torch

# Reference default values of atol and rtol are from
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}


def get_default_atol(output) -> float:
    return default_atol[output.dtype]


def get_default_rtol(output) -> float:
    return default_rtol[output.dtype]
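For context, here is a minimal usage sketch (not part of the commit; the tensor shape and values are arbitrary) showing how a test picks its tolerances from the output tensor's dtype:

import torch

from allclose_default import get_default_atol, get_default_rtol

# Arbitrary bfloat16 tensors standing in for a kernel output and its reference.
out = torch.randn(7, 128, dtype=torch.bfloat16)
ref = out.clone()

# bfloat16 maps to atol=1e-3 and rtol=1.6e-2 in the dictionaries above.
assert torch.allclose(out, ref,
                      atol=get_default_atol(out),
                      rtol=get_default_rtol(out))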
tests/kernels/test_activation.py
@@ -2,6 +2,7 @@ import pytest
 import torch

 from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
+from allclose_default import get_default_atol, get_default_rtol

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
@@ -33,7 +34,10 @@ def test_silu_and_mul(
     layer = SiluAndMul()
     out = layer(x)
     ref_out = layer._forward(x)
-    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(out,
+                          ref_out,
+                          atol=get_default_atol(out),
+                          rtol=get_default_rtol(out))


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -57,7 +61,10 @@ def test_gelu_new(
     layer = NewGELU()
     out = layer(x)
     ref_out = layer._forward(x)
-    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(out,
+                          ref_out,
+                          atol=get_default_atol(out),
+                          rtol=get_default_rtol(out))


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -80,4 +87,7 @@ def test_gelu_fast(
     layer = FastGELU()
     out = layer(x)
     ref_out = layer._forward(x)
-    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(out,
+                          ref_out,
+                          atol=get_default_atol(out),
+                          rtol=get_default_rtol(out))
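The pattern in these tests is to run the layer's fused kernel path (layer(x)) against its pure-PyTorch reference (layer._forward(x)) and compare the two with dtype-aware tolerances. The following self-contained sketch mirrors that idea without the vLLM extension modules; the silu_and_mul reference semantics and the float32 reference path are assumptions for illustration, not code from this commit:

import torch
import torch.nn.functional as F

# Dtype-based tolerances, copied from allclose_default.py above.
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Assumed reference semantics: split the last dimension in half and
    # return silu(a) * b.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


x = torch.randn(7, 2 * 128, dtype=torch.bfloat16)
out = silu_and_mul(x)                      # stand-in for the fused kernel output
ref = silu_and_mul(x.float()).to(x.dtype)  # higher-precision reference
assert torch.allclose(out, ref,
                      atol=default_atol[out.dtype],
                      rtol=default_rtol[out.dtype])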
tests/kernels/test_attention.py
@@ -8,6 +8,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 from vllm._C import ops, cache_ops
 from vllm.utils import get_max_shared_memory_bytes
+from vllm.utils import is_hip
+from allclose_default import get_default_atol, get_default_rtol

 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
@@ -17,12 +19,18 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 # Reduce NUM_BLOCKS when it happens.
 NUM_BLOCKS = 4321  # Arbitrary values for testing
 PARTITION_SIZE = 512
-DTYPES = [torch.half, torch.bfloat16, torch.float]
+# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
+DTYPES = [torch.half, torch.bfloat16,
+          torch.float] if not is_hip() else [torch.half, torch.bfloat16]
 NUM_GEN_SEQS = [7]  # Arbitrary values for testing
 NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
 NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
-HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+# FlashAttention forward only supports head dimension at most 128
+# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
+HEAD_SIZES = [64, 80, 96, 112,
+              128, 256] if not is_hip() else [64, 80, 96, 112, 128]
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
 KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
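These parameter lists are gated on vllm.utils.is_hip(). As a rough sketch of what that gate does (the body of is_hip here is an assumption, not code from this commit), the check amounts to asking whether PyTorch was built against ROCm/HIP:

import torch


def is_hip() -> bool:
    # Assumed behaviour of vllm.utils.is_hip: torch.version.hip is a version
    # string on ROCm builds of PyTorch and None on CUDA/CPU builds.
    return torch.version.hip is not None


# On a ROCm build this drops torch.float from DTYPES and 256 from HEAD_SIZES,
# matching the comments in the hunk above.
DTYPES = ([torch.half, torch.bfloat16, torch.float]
          if not is_hip() else [torch.half, torch.bfloat16])
HEAD_SIZES = ([64, 80, 96, 112, 128, 256]
              if not is_hip() else [64, 80, 96, 112, 128])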
@@ -251,9 +259,11 @@ def test_paged_attention(
     # NOTE(woosuk): Due to the kernel-level differences in the two
     # implementations, there is a small numerical difference in the two
     # outputs. Thus, we use a relaxed tolerance for the test.
+    atol = get_default_atol(output) if is_hip() else 1e-3
+    rtol = get_default_rtol(output) if is_hip() else 1e-5
     # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
     # so we use a relaxed tolerance for the test.
-    atol, rtol = 1e-3, 1e-5
     if kv_cache_dtype == "fp8_e5m2":
         atol, rtol = 1e-2, 1e-5
     assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
@@ -357,4 +367,6 @@ def test_multi_query_kv_attention(
         scale,
         dtype,
     )
-    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
+    atol = get_default_atol(output) if is_hip() else 1e-3
+    rtol = get_default_rtol(output) if is_hip() else 1e-5
+    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
tests/kernels/test_cache.py
@@ -6,6 +6,7 @@ import torch
 from typing import Tuple

 from vllm._C import cache_ops
+from vllm.utils import is_hip

 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -14,7 +15,10 @@ NUM_LAYERS = [1] # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [8, 16, 32]
-NUM_BLOCKS = [1024, 3600]  # Arbitrary values for testing
+# reduce the size for ROCm test to avoid HIP OOM
+NUM_BLOCKS = [1024, 36000] if not is_hip else [
+    1024, 10000
+]  # Arbitrary values for testing
 NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]
 CUDA_DEVICES = [
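Note that this hunk references is_hip without the call parentheses used elsewhere in the commit. In Python that tests the truthiness of the function object rather than the platform, so the two spellings behave differently; a small sketch, assuming is_hip is the zero-argument boolean helper used in the other files:

def is_hip() -> bool:
    return False  # pretend we are on a CUDA build


# Calling the helper gates on the platform:
num_blocks_called = [1024, 36000] if not is_hip() else [1024, 10000]

# Referencing it without parentheses gates on the function object itself,
# which is always truthy, so the else branch is always taken:
num_blocks_uncalled = [1024, 36000] if not is_hip else [1024, 10000]

assert num_blocks_called == [1024, 36000]
assert num_blocks_uncalled == [1024, 10000]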
tests/kernels/test_pos_encoding.py
@@ -2,7 +2,7 @@ from typing import Optional
 import pytest
 import torch

+from allclose_default import get_default_atol, get_default_rtol
 from vllm.model_executor.layers.rotary_embedding import get_rope

 IS_NEOX_STYLE = [True, False]
@@ -64,5 +64,11 @@ def test_rotary_embedding(
     ref_query, ref_key = rope._forward(positions, query, key)
     out_query, out_key = rope.forward(positions, query, key)
     # Compare the results.
-    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
-    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(out_query,
+                          ref_query,
+                          atol=get_default_atol(out_query),
+                          rtol=get_default_rtol(out_query))
+    assert torch.allclose(out_key,
+                          ref_key,
+                          atol=get_default_atol(out_key),
+                          rtol=get_default_rtol(out_key))
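A note on the bare "from allclose_default import ..." lines in these test files: allclose_default.py sits next to the tests, so the import resolves only when tests/kernels is on sys.path. A minimal sketch of why that works under pytest (assuming tests/kernels has no __init__.py and pytest's default "prepend" import mode is in effect):

import sys

# pytest effectively prepends the test file's directory, so the sibling module
# allclose_default.py becomes importable by its bare name.
sys.path.insert(0, "tests/kernels")

from allclose_default import get_default_atol, get_default_rtol  # noqa: E402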