norm / vllm · Commits

Commit 825d8892 (unverified), authored May 17, 2023 by Woosuk Kwon and committed by GitHub on May 17, 2023.

Use pytest format for unit tests (#107)

Parent: b322fd16

Showing 5 changed files with 43 additions and 37 deletions (+43 −37):
tests/kernels/test_activation.py    +4  −4
tests/kernels/test_attention.py     +14 −14
tests/kernels/test_cache.py         +18 −12
tests/kernels/test_layernorm.py     +3  −3
tests/kernels/test_pos_encoding.py  +4  −4
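Each file follows the same conversion: the parametrized kernel check loses its test_ prefix and becomes a run_* helper (decorated with @torch.inference_mode() where it was not already), while the old if __name__ == '__main__': sweep becomes a zero-argument test_* function that pytest can collect. A minimal sketch of the pattern, using a hypothetical my_kernel check rather than any of the real kernels:

import torch


@torch.inference_mode()
def run_my_kernel(num_tokens: int, dtype: torch.dtype) -> None:
    # Hypothetical check: compare an op against a trivial reference.
    x = torch.randn(num_tokens, 8).to(dtype)
    assert torch.allclose(x * 2, x + x, atol=1e-5, rtol=1e-5)


def test_my_kernel() -> None:
    # Zero-argument wrapper that sweeps the parameter grid, so a plain
    # `pytest tests/kernels` run discovers and executes it.
    for dtype in [torch.half, torch.float]:
        for num_tokens in [7, 83]:
            run_my_kernel(num_tokens, dtype)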
tests/kernels/activation.py → tests/kernels/test_activation.py

@@ -10,7 +10,7 @@ def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
 
 
 @torch.inference_mode()
-def test_silu_and_mul(
+def run_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -22,9 +22,9 @@ def test_silu_and_mul(
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
 
 
-if __name__ == '__main__':
+def test_silu_and_mul() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 13824]:
+            for d in [512, 4096, 5120, 13824]:
                 print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                test_silu_and_mul(num_tokens, d, dtype)
+                run_silu_and_mul(num_tokens, d, dtype)
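With the if __name__ == '__main__': block gone, the sweep now runs only through pytest collection. If a script-style entry point is still wanted, one option (not part of this commit, shown only as a sketch) is to delegate to pytest programmatically:

import sys

import pytest

if __name__ == '__main__':
    # Run the converted test module through pytest; -s keeps the print()
    # progress lines visible, -v lists each collected test function.
    sys.exit(pytest.main(['-v', '-s', 'tests/kernels/test_activation.py']))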
tests/kernels/attention.py → tests/kernels/test_attention.py

@@ -8,6 +8,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 from cacheflow import attention_ops
 
 MAX_SEQ_LEN = 4096
+TEST_SEED = 0
 
 
 def ref_masked_attention(
@@ -155,7 +156,8 @@ def ref_multi_query_cached_kv_attention(
     return ref_output
 
 
-def test_single_query_cached_kv_attention(
+@torch.inference_mode()
+def run_single_query_cached_kv_attention(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -223,7 +225,8 @@ def test_single_query_cached_kv_attention(
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
-def test_multi_query_kv_attention(
+@torch.inference_mode()
+def run_multi_query_kv_attention(
     num_seqs: int,
     num_heads: int,
     head_size: int,
@@ -264,19 +267,16 @@ def test_multi_query_kv_attention(
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
-@torch.inference_mode()
-def test_attention(seed: int) -> None:
-    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
-    # the test fails due to the precision issue. Re-run the test if it fails.
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+def test_single_query_cached_kv_attention() -> None:
+    torch.random.manual_seed(TEST_SEED)
+    torch.cuda.manual_seed(TEST_SEED)
     for dtype in [torch.half, torch.bfloat16]:
         for block_size in [8, 16, 32, 64]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
                       f'head_size={head_size}')
-                test_single_query_cached_kv_attention(
+                run_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
                     head_size=head_size,
@@ -285,17 +285,17 @@ def test_attention(seed: int) -> None:
                     dtype=dtype,
                 )
 
+
+def test_multi_query_kv_attention() -> None:
+    torch.random.manual_seed(TEST_SEED)
+    torch.cuda.manual_seed(TEST_SEED)
     for dtype in [torch.half, torch.bfloat16]:
         for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                   f'head_size={head_size}')
-            test_multi_query_kv_attention(
+            run_multi_query_kv_attention(
                 num_seqs=5,
                 num_heads=3,
                 head_size=head_size,
                 dtype=dtype,
             )
-
-
-if __name__ == '__main__':
-    test_attention(seed=0)
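The old test_attention(seed: int) entry point, and its note about re-running on precision failures, is replaced by a module-level TEST_SEED that each collected test applies before drawing random inputs. A small, purely illustrative sketch of that seeding pattern:

import torch

TEST_SEED = 0


def test_reproducible_inputs() -> None:
    # Seeding both the CPU and CUDA RNGs makes the random inputs identical
    # across runs, so tolerance-based comparisons (atol/rtol) stay stable.
    # torch.cuda.manual_seed is silently ignored when CUDA is unavailable.
    torch.random.manual_seed(TEST_SEED)
    torch.cuda.manual_seed(TEST_SEED)
    a = torch.randn(16, 16)

    torch.random.manual_seed(TEST_SEED)
    torch.cuda.manual_seed(TEST_SEED)
    b = torch.randn(16, 16)

    assert torch.equal(a, b)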
tests/kernels/cache.py → tests/kernels/test_cache.py

@@ -5,7 +5,8 @@ import torch
 from cacheflow import cache_ops
 
 
-def test_copy_blocks(
+@torch.inference_mode()
+def run_copy_blocks(
    num_mappings: int,
    num_layers: int,
    num_heads: int,
@@ -60,7 +61,8 @@ def test_copy_blocks(
     assert torch.allclose(value_cache, cloned_value_cache)
 
 
-def test_reshape_and_cache(
+@torch.inference_mode()
+def run_reshape_and_cache(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -99,7 +101,8 @@ def test_reshape_and_cache(
     assert torch.allclose(value_cache, cloned_value_cache)
 
 
-def test_gather_cached_kv(
+@torch.inference_mode()
+def run_gather_cached_kv(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -140,19 +143,22 @@ def test_gather_cached_kv(
     assert torch.allclose(value, cloned_value)
 
 
-@torch.inference_mode()
-def test_cache() -> None:
+def test_copy_blocks() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
-        test_copy_blocks(
+        run_copy_blocks(
             num_mappings=23, num_layers=7, num_heads=17, head_size=16,
             block_size=8, num_blocks=1024, dtype=dtype)
-        test_reshape_and_cache(
-            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-            dtype=dtype)
-        test_gather_cached_kv(
+
+
+def test_reshape_and_cache() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        run_reshape_and_cache(
             num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
             dtype=dtype)
 
 
-if __name__ == '__main__':
-    test_cache()
+def test_gather_cached_kv() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        run_gather_cached_kv(
+            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
+            dtype=dtype)
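The three new wrappers keep explicit for-dtype loops, so each wrapper is still reported as a single test regardless of how many dtypes it sweeps. An alternative the commit does not use is pytest.mark.parametrize, which reports each dtype as its own case. A sketch, assuming it sits in tests/kernels/test_cache.py next to run_copy_blocks:

import pytest
import torch


@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float])
def test_copy_blocks(dtype: torch.dtype) -> None:
    # One collected test per dtype; a failure pinpoints the offending dtype.
    run_copy_blocks(
        num_mappings=23, num_layers=7, num_heads=17, head_size=16,
        block_size=8, num_blocks=1024, dtype=dtype)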
tests/kernels/layernorm.py → tests/kernels/test_layernorm.py

@@ -22,7 +22,7 @@ class RefRMSNorm(nn.Module):
 
 
 @torch.inference_mode()
-def test_rms_norm(
+def run_rms_norm(
     num_tokens: int,
     hidden_size: int,
     dtype: torch.dtype,
@@ -41,13 +41,13 @@ def test_rms_norm(
     assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
 
 
-if __name__ == '__main__':
+def test_rms_norm() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 128, 2048]:
             for hidden_size in [13, 64, 1024, 5120]:
                 print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
                       f'{num_tokens}, hidden_size={hidden_size}')
-                test_rms_norm(
+                run_rms_norm(
                     num_tokens=num_tokens,
                     hidden_size=hidden_size,
                     dtype=dtype,
tests/kernels/pos_encoding.py → tests/kernels/test_pos_encoding.py

@@ -76,7 +76,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
 
 
 @torch.inference_mode()
-def test_rotary_embedding_neox(
+def run_rotary_embedding_neox(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -128,15 +128,15 @@ def test_rotary_embedding_neox(
     assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
 
 
-if __name__ == '__main__':
+def test_rotary_embedding_neox() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Running tests for head_size={head_size} and dtype={dtype}')
-            test_rotary_embedding_neox(
+            run_rotary_embedding_neox(
                 num_tokens=2145,
                 num_heads=5,
                 head_size=head_size,
                 max_position=8192,
-                rotary_dim=int(head_size * 0.25),
+                rotary_dim=head_size,
                 dtype=dtype,
             )
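Besides the mechanical rename, this file also switches the tested rotary_dim from int(head_size * 0.25) to head_size, i.e. from a partial to a full rotary dimension. As an illustration only (not code from this commit), rotary_dim is the number of leading channels per head that the rotary embedding rotates; the remaining channels pass through unchanged:

import torch

head_size = 64
rotary_dim = int(head_size * 0.25)  # the old, partial setting

x = torch.randn(1, head_size)
x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
print(x_rot.shape, x_pass.shape)  # torch.Size([1, 16]) torch.Size([1, 48])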