norm / vllm · Commit 825d8892 (unverified)

Use pytest format for unit tests (#107)

Authored May 17, 2023 by Woosuk Kwon; committed by GitHub on May 17, 2023.
Parent: b322fd16
Showing 5 changed files with 43 additions and 37 deletions (+43 −37).
tests/kernels/test_activation.py    +4   −4
tests/kernels/test_attention.py     +14  −14
tests/kernels/test_cache.py         +18  −12
tests/kernels/test_layernorm.py     +3   −3
tests/kernels/test_pos_encoding.py  +4   −4
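The same conversion is applied to all five files: each @torch.inference_mode() kernel driver that used to be named test_* is renamed to run_*, and a new zero-argument test_* function loops over the dtype/shape configurations (previously under if __name__ == '__main__':) and calls the driver, so pytest can collect it. A minimal sketch of that shape, using generic placeholder names rather than the actual kernels from this diff:

import torch


@torch.inference_mode()
def run_my_kernel(num_tokens: int, d: int, dtype: torch.dtype) -> None:
    # Exercise one configuration and compare against a reference result.
    # (Placeholder check; the real tests call the CUDA kernels and compare
    # their output against a PyTorch reference implementation.)
    x = torch.randn(num_tokens, d).to(dtype)
    assert torch.allclose(x.float(), x.float())


def test_my_kernel() -> None:
    # pytest discovers this function by its test_ prefix and the test_*.py filename.
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        for num_tokens in [7, 83, 2048]:
            for d in [512, 4096]:
                run_my_kernel(num_tokens, d, dtype)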
tests/kernels/activation.py → tests/kernels/test_activation.py

@@ -10,7 +10,7 @@ def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
 @torch.inference_mode()
-def test_silu_and_mul(
+def run_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -22,9 +22,9 @@ def test_silu_and_mul(
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


-if __name__ == '__main__':
+def test_silu_and_mul() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 13824]:
+            for d in [512, 4096, 5120, 13824]:
                 print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                test_silu_and_mul(num_tokens, d, dtype)
+                run_silu_and_mul(num_tokens, d, dtype)
tests/kernels/attention.py → tests/kernels/test_attention.py

@@ -8,6 +8,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 from cacheflow import attention_ops

 MAX_SEQ_LEN = 4096
+TEST_SEED = 0


 def ref_masked_attention(
@@ -155,7 +156,8 @@ def ref_multi_query_cached_kv_attention(
     return ref_output


-def test_single_query_cached_kv_attention(
+@torch.inference_mode()
+def run_single_query_cached_kv_attention(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -223,7 +225,8 @@ def test_single_query_cached_kv_attention(
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


-def test_multi_query_kv_attention(
+@torch.inference_mode()
+def run_multi_query_kv_attention(
     num_seqs: int,
     num_heads: int,
     head_size: int,
@@ -264,19 +267,16 @@ def test_multi_query_kv_attention(
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


-@torch.inference_mode()
-def test_attention(seed: int) -> None:
-    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
-    # the test fails due to the precision issue. Re-run the test if it fails.
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+def test_single_query_cached_kv_attention() -> None:
+    torch.random.manual_seed(TEST_SEED)
+    torch.cuda.manual_seed(TEST_SEED)
     for dtype in [torch.half, torch.bfloat16]:
         for block_size in [8, 16, 32, 64]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
                       f'head_size={head_size}')
-                test_single_query_cached_kv_attention(
+                run_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
                     head_size=head_size,
@@ -285,17 +285,17 @@ def test_attention(seed: int) -> None:
                     dtype=dtype,
                 )

+
+def test_multi_query_kv_attention() -> None:
+    torch.random.manual_seed(TEST_SEED)
+    torch.cuda.manual_seed(TEST_SEED)
     for dtype in [torch.half, torch.bfloat16]:
         for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                   f'head_size={head_size}')
-            test_multi_query_kv_attention(
+            run_multi_query_kv_attention(
                 num_seqs=5,
                 num_heads=3,
                 head_size=head_size,
                 dtype=dtype,
             )
-
-
-if __name__ == '__main__':
-    test_attention(seed=0)
tests/kernels/cache.py → tests/kernels/test_cache.py

@@ -5,7 +5,8 @@ import torch
 from cacheflow import cache_ops


-def test_copy_blocks(
+@torch.inference_mode()
+def run_copy_blocks(
     num_mappings: int,
     num_layers: int,
     num_heads: int,
@@ -60,7 +61,8 @@ def test_copy_blocks(
     assert torch.allclose(value_cache, cloned_value_cache)


-def test_reshape_and_cache(
+@torch.inference_mode()
+def run_reshape_and_cache(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -99,7 +101,8 @@ def test_reshape_and_cache(
     assert torch.allclose(value_cache, cloned_value_cache)


-def test_gather_cached_kv(
+@torch.inference_mode()
+def run_gather_cached_kv(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -140,19 +143,22 @@ def test_gather_cached_kv(
     assert torch.allclose(value, cloned_value)


-@torch.inference_mode()
-def test_cache() -> None:
+def test_copy_blocks() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
-        test_copy_blocks(
+        run_copy_blocks(
            num_mappings=23, num_layers=7, num_heads=17, head_size=16,
            block_size=8, num_blocks=1024, dtype=dtype)
-        test_reshape_and_cache(
-            num_tokens=3, num_heads=2, head_size=16, block_size=8,
-            num_blocks=2, dtype=dtype)
-        test_gather_cached_kv(
+
+
+def test_reshape_and_cache() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        run_reshape_and_cache(
             num_tokens=3, num_heads=2, head_size=16, block_size=8,
             num_blocks=2, dtype=dtype)


-if __name__ == '__main__':
-    test_cache()
+def test_gather_cached_kv() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        run_gather_cached_kv(
+            num_tokens=3, num_heads=2, head_size=16, block_size=8,
+            num_blocks=2, dtype=dtype)
tests/kernels/layernorm.py → tests/kernels/test_layernorm.py

@@ -22,7 +22,7 @@ class RefRMSNorm(nn.Module):
 @torch.inference_mode()
-def test_rms_norm(
+def run_rms_norm(
     num_tokens: int,
     hidden_size: int,
     dtype: torch.dtype,
@@ -41,13 +41,13 @@ def test_rms_norm(
     assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)


-if __name__ == '__main__':
+def test_rms_norm() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 128, 2048]:
             for hidden_size in [13, 64, 1024, 5120]:
                 print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
                       f'{num_tokens}, hidden_size={hidden_size}')
-                test_rms_norm(
+                run_rms_norm(
                     num_tokens=num_tokens,
                     hidden_size=hidden_size,
                     dtype=dtype,
tests/kernels/pos_encoding.py → tests/kernels/test_pos_encoding.py

@@ -76,7 +76,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
 @torch.inference_mode()
-def test_rotary_embedding_neox(
+def run_rotary_embedding_neox(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -128,15 +128,15 @@ def test_rotary_embedding_neox(
     assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)


-if __name__ == '__main__':
+def test_rotary_embedding_neox() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Running tests for head_size={head_size} and dtype={dtype}')
-            test_rotary_embedding_neox(
+            run_rotary_embedding_neox(
                 num_tokens=2145,
                 num_heads=5,
                 head_size=head_size,
                 max_position=8192,
-                rotary_dim=int(head_size * 0.25),
+                rotary_dim=head_size,
                 dtype=dtype,
             )
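With the files renamed to test_*.py and the drivers wrapped in zero-argument test_* functions, the whole suite can presumably be collected by pytest directly rather than run file by file. For example (assuming a CUDA-capable machine with the cacheflow extensions built, since the kernels under test run on GPU):

pytest tests/kernels -s

The -s flag is optional; it keeps the print() progress lines from each test visible instead of letting pytest capture them.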