xdb4_94051 / vllm / Commits

Commit fbd80ad4 (unverified)
Clean up kernel unit tests (#938)
Authored Sep 06, 2023 by Woosuk Kwon; committed by GitHub Sep 05, 2023
Parent: 22379d55

Showing 6 changed files with 364 additions and 399 deletions (+364, −399):

- tests/kernels/conftest.py (+43, −0)
- tests/kernels/test_activation.py (+31, −28)
- tests/kernels/test_attention.py (+159, −198)
- tests/kernels/test_cache.py (+80, −130)
- tests/kernels/test_layernorm.py (+24, −22)
- tests/kernels/test_pos_encoding.py (+27, −21)
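
The same cleanup runs through all six files: the hand-rolled `run_*` helpers and their nested `for`/`print` driver loops are replaced by stacked `pytest.mark.parametrize` decorators plus an explicit `seed` parameter, so every value combination becomes an independent, reproducibly seeded test case that pytest can discover and report on its own (e.g. via `pytest tests/kernels -v`). The snippet below is an illustrative, self-contained sketch of that pattern, not an excerpt from the diff; the decorators multiply, so two lists of 3 and 2 values expand into 6 cases.

```python
import pytest
import torch

DTYPES = [torch.half, torch.bfloat16, torch.float]
SIZES = [8, 16]


# Stacked parametrize decorators take the Cartesian product of their values,
# so this single function expands into len(DTYPES) * len(SIZES) = 6 tests.
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("size", SIZES)
def test_zeros(dtype: torch.dtype, size: int) -> None:
    out = torch.zeros(size, dtype=dtype)
    assert out.shape == (size, )
    assert out.dtype == dtype
```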
tests/kernels/conftest.py — new file (0 → 100644), view file @ fbd80ad4

```python
from typing import List, Tuple

import pytest
import torch


def create_kv_caches(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
    seed: int,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = head_size**-0.5
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape,
                                dtype=dtype,
                                device='cuda')
        key_cache.uniform_(-scale, scale)
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape,
                                  dtype=dtype,
                                  device='cuda')
        value_cache.uniform_(-scale, scale)
        value_caches.append(value_cache)
    return key_caches, value_caches


@pytest.fixture()
def kv_cache_factory():
    return create_kv_caches
```
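
Since pytest picks up fixtures from `conftest.py` automatically, any test under `tests/kernels/` can request `kv_cache_factory` by name and receive the `create_kv_caches` function, which is how the updated cache and attention tests below consume it. A minimal illustrative consumer is sketched here (not part of the commit; like the rest of the suite it needs a CUDA device, since the caches are allocated on `'cuda'`):

```python
import torch


def test_kv_cache_factory_shapes(kv_cache_factory):
    # The fixture returns create_kv_caches itself; calling it builds one
    # key cache and one value cache per layer with a fixed seed.
    key_caches, value_caches = kv_cache_factory(
        num_blocks=128, block_size=16, num_layers=2,
        num_heads=8, head_size=64, dtype=torch.half, seed=0)
    assert len(key_caches) == len(value_caches) == 2
    # Value cache layout: (num_blocks, num_heads, head_size, block_size).
    assert value_caches[0].shape == (128, 8, 64, 16)
    # Key cache layout packs head_size as (head_size // x, ..., x).
    assert key_caches[0].shape[0] == 128
```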
tests/kernels/test_activation.py — view file @ fbd80ad4

```diff
+import pytest
 import torch
 import torch.nn.functional as F
 from transformers.activations import get_activation

 from vllm import activation_ops

+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
+SEEDS = [0]
+

 def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     x1, x2 = x.chunk(chunks=2, dim=1)
     return F.silu(x1) * x2


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_silu_and_mul(
+def test_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.silu_and_mul(out, x)
@@ -22,20 +36,19 @@ def run_silu_and_mul(
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


-def test_silu_and_mul() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_silu_and_mul(num_tokens, d, dtype)
-
-
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_gelu_new(
+def test_gelu_new(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.gelu_new(out, x)
@@ -43,30 +56,20 @@ def run_gelu_new(
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


-def test_gelu_new() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_gelu_new(num_tokens, d, dtype)
-
-
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_gelu_fast(
+def test_gelu_fast(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.gelu_fast(out, x)
     ref_out = get_activation("gelu_fast")(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
-
-
-def test_gelu_fast() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_gelu_fast(num_tokens, d, dtype)
```
tests/kernels/test_attention.py — view file @ fbd80ad4

```diff
 import random
-from typing import List, Optional
+from typing import List, Optional, Tuple

+import pytest
 import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

 from vllm import attention_ops

-MAX_SEQ_LEN = 4096
-TEST_SEED = 0
+MAX_SEQ_LEN = 8192
+NUM_BLOCKS = 128  # Arbitrary values for testing
+
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_GEN_SEQS = [7]  # Arbitrary values for testing
+NUM_PREFILL_SEQS = [1, 3, 7]  # Arbitrary values for testing
+NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
+HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+BLOCK_SIZES = [8, 16, 32]
+USE_ALIBI = [False]  # TODO(woosuk): Add USE_ALIBI=True
+SEEDS = [0]


 def ref_masked_attention(
@@ -18,29 +28,34 @@ def ref_masked_attention(
     scale: float,
     attn_mask: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    query = query * scale
-    attn = torch.einsum('qhd,khd->hqk', query, key)
+    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
     if attn_mask is not None:
-        attn = attn + attn_mask
-    attn = torch.softmax(attn, dim=-1)
-    out = torch.einsum('hqk,khd->qhd', attn, value)
+        attn_weights = attn_weights + attn_mask.float()
+    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
     return out


 def ref_single_query_cached_kv_attention(
     output: torch.Tensor,
     query: torch.Tensor,
+    num_queries_per_kv: int,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     block_tables: torch.Tensor,
     context_lens: torch.Tensor,
+    scale: float,
+    alibi_slopes: Optional[torch.Tensor],
 ) -> None:
-    num_heads = value_cache.shape[1]
+    num_query_heads = query.shape[1]
+    num_kv_heads = value_cache.shape[1]
     head_size = value_cache.shape[2]
     block_size = value_cache.shape[3]
-    num_input_tokens = query.shape[0]
-    for i in range(num_input_tokens):
+    num_seqs = query.shape[0]
+
+    block_tables = block_tables.cpu().tolist()
+    context_lens = context_lens.cpu().tolist()
+    for i in range(num_seqs):
         q = query[i].unsqueeze(0)
         block_table = block_tables[i]
         context_len = int(context_lens[i])
@@ -52,170 +67,96 @@ def ref_single_query_cached_kv_attention(
             block_offset = j % block_size
             k = key_cache[block_number, :, :, block_offset, :]
-            k = k.reshape(num_heads, head_size)
+            k = k.reshape(num_kv_heads, head_size)
             keys.append(k)
             v = value_cache[block_number, :, :, block_offset]
             values.append(v)
         keys = torch.stack(keys, dim=0)
         values = torch.stack(values, dim=0)
+        if num_queries_per_kv > 1:
+            # Handle MQA and GQA
+            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
+            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

-        scale = 1.0 / (head_size**0.5)
-        out = ref_masked_attention(q, keys, values, scale)
-        out = out.view(num_heads, head_size)
+        alibi_bias = None
+        if alibi_slopes is not None:
+            # Create the ALiBi bias used in the paged attention kernel.
+            position_ids = torch.arange(context_len, device="cuda").int()
+            alibi_bias = (context_len - position_ids).float()
+            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)
+
+        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
+        out = out.view(num_query_heads, head_size)
         output[i].copy_(out, non_blocking=True)


-def ref_multi_query_kv_attention(
-    cu_seq_lens: List[int],
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    head_size = query.shape[-1]
-    scale = 1.0 / (head_size**0.5)
-    num_seqs = len(cu_seq_lens) - 1
-    ref_outputs = []
-    for i in range(num_seqs):
-        start_idx = cu_seq_lens[i]
-        end_idx = cu_seq_lens[i + 1]
-        seq_len = end_idx - start_idx
-
-        # Create attention mask.
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
-                               diagonal=1)
-        attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
-
-        ref_output = ref_masked_attention(
-            query[start_idx:end_idx],
-            key[start_idx:end_idx],
-            value[start_idx:end_idx],
-            scale,
-            attn_mask=attn_mask,
-        )
-        ref_outputs.append(ref_output)
-    ref_output = torch.cat(ref_outputs, dim=0)
-    return ref_output
-
-
-def ref_multi_query_cached_kv_attention(
-    cu_query_lens: List[int],
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    block_tables: torch.Tensor,
-    context_lens: torch.Tensor,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    num_heads = value_cache.shape[1]
-    head_size = value_cache.shape[2]
-    block_size = value_cache.shape[3]
-    scale = 1.0 / (head_size**0.5)
-
-    num_queries = len(cu_query_lens) - 1
-    ref_outputs = []
-    for i in range(num_queries):
-        start_idx = cu_query_lens[i]
-        end_idx = cu_query_lens[i + 1]
-        query_len = end_idx - start_idx
-        context_len = int(context_lens[i])
-        block_table = block_tables[i]
-
-        # Create attention mask
-        attn_mask = torch.triu(torch.ones(query_len, context_len),
-                               diagonal=context_len - query_len + 1) * -1e5
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
-
-        keys = []
-        values = []
-        for j in range(context_len):
-            block_number = int(block_table[j // block_size])
-            block_offset = j % block_size
-            k = key_cache[block_number, :, :, block_offset, :]
-            k = k.reshape(num_heads, head_size)
-            keys.append(k)
-            v = value_cache[block_number, :, :, block_offset]
-            values.append(v)
-        keys = torch.stack(keys, dim=0)
-        values = torch.stack(values, dim=0)
-
-        ref_output = ref_masked_attention(
-            query[start_idx:end_idx],
-            keys,
-            values,
-            scale,
-            attn_mask=attn_mask,
-        )
-        ref_outputs.append(ref_output)
-    ref_output = torch.cat(ref_outputs, dim=0)
-    return ref_output
-
-
+@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("use_alibi", USE_ALIBI)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_single_query_cached_kv_attention(
-    num_tokens: int,
-    num_heads: int,
+def test_single_query_cached_kv_attention(
+    kv_cache_factory,
+    num_seqs: int,
+    num_heads: Tuple[int, int],
     head_size: int,
+    use_alibi: bool,
     block_size: int,
-    num_blocks: int,
     dtype: torch.dtype,
-    num_kv_heads: int = None,
+    seed: int,
 ) -> None:
-    qkv = torch.empty(num_tokens,
-                      3,
-                      num_heads,
-                      head_size,
-                      dtype=dtype,
-                      device='cuda')
-    qkv.uniform_(-1e-3, 1e-3)
-    query, _, _ = qkv.unbind(dim=1)
-
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_block_shape = (num_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(size=(num_blocks, *key_block_shape),
-                            dtype=dtype,
-                            device='cuda')
-    key_cache.uniform_(-1e-3, 1e-3)
-    value_block_shape = (num_heads, head_size, block_size)
-    value_cache = torch.empty(size=(num_blocks, *value_block_shape),
-                              dtype=dtype,
-                              device='cuda')
-    value_cache.uniform_(-1e-3, 1e-3)
-    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    scale = float(1.0 / (head_size**0.5))
+    num_query_heads, num_kv_heads = num_heads
+    query = torch.empty(num_seqs,
+                        num_query_heads,
+                        head_size,
+                        dtype=dtype,
+                        device="cuda")
+    query.uniform_(-scale, scale)
+
+    assert num_query_heads % num_kv_heads == 0
+    num_queries_per_kv = num_query_heads // num_kv_heads
+    head_mapping = torch.repeat_interleave(
+        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
+        num_queries_per_kv)
+    alibi_slopes = None
+    if use_alibi:
+        alibi_slopes = torch.randn(num_query_heads,
+                                   dtype=torch.float,
+                                   device="cuda")
+
+    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
     max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
+    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

+    # Create the block tables.
     max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
     block_tables = []
-    for _ in range(num_tokens):
+    for _ in range(num_seqs):
         block_table = [
-            random.randint(0, num_blocks - 1)
+            random.randint(0, NUM_BLOCKS - 1)
             for _ in range(max_num_blocks_per_seq)
         ]
         block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
-    head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
-
-    scale = float(1.0 / (head_size**0.5))
-
-    num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-    assert num_heads % num_kv_heads == 0
-    num_queries_per_kv = num_heads // num_kv_heads
-    head_mapping = torch.repeat_interleave(
-        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
-        num_queries_per_kv)
-
-    output = torch.empty(num_tokens,
-                         num_heads,
-                         head_size,
-                         dtype=dtype,
-                         device='cuda')
+    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
+
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
+                                                num_kv_heads, head_size, dtype,
+                                                seed)
+    key_cache, value_cache = key_caches[0], value_caches[0]
+
+    # Call the paged attention kernel.
+    output = torch.empty_like(query)
     attention_ops.single_query_cached_kv_attention(
         output,
         query,
@@ -227,45 +168,98 @@ def run_single_query_cached_kv_attention(
         context_lens,
         block_size,
         max_context_len,
-        None,  # ALiBi slopes.
+        alibi_slopes,
     )

+    # Run the reference implementation.
     ref_output = torch.empty_like(query)
     ref_single_query_cached_kv_attention(
         ref_output,
         query,
+        num_queries_per_kv,
         key_cache,
         value_cache,
         block_tables,
         context_lens,
+        scale,
+        alibi_slopes,
     )

-    # NOTE(woosuk): Due to the difference in the data types the two
-    # implementations use for attention softmax logits and accumulation,
-    # there is a small difference in the final outputs.
-    # We should use a relaxed tolerance for the test.
+    # NOTE(woosuk): Due to the kernel-level differences in the two
+    # implementations, there is a small numerical difference in the two
+    # outputs. Thus, we use a relaxed tolerance for the test.
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


+def ref_multi_query_kv_attention(
+    cu_seq_lens: List[int],
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: float,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    num_seqs = len(cu_seq_lens) - 1
+    ref_outputs = []
+    for i in range(num_seqs):
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        seq_len = end_idx - start_idx
+
+        # Create attention mask.
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                               diagonal=1)
+        attn_mask = attn_mask * torch.finfo(dtype).min
+        attn_mask = attn_mask.to(dtype=dtype, device="cuda")
+
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+    return ref_output
+
+
+@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_multi_query_kv_attention(
+def test_multi_query_kv_attention(
     num_seqs: int,
-    num_heads: int,
+    num_heads: Tuple[int, int],
     head_size: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
     num_tokens = sum(seq_lens)

     scale = float(1.0 / (head_size**0.5))
+    num_query_heads, num_kv_heads = num_heads
     qkv = torch.empty(num_tokens,
-                      3,
-                      num_heads,
+                      num_query_heads + 2 * num_kv_heads,
                       head_size,
                       dtype=dtype,
-                      device='cuda')
-    qkv.uniform_(-1e-3, 1e-3)
-    query, key, value = qkv.unbind(dim=1)
+                      device="cuda")
+    qkv.uniform_(-scale, scale)
+    query, key, value = qkv.split(
+        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
+
+    num_queries_per_kv = num_query_heads // num_kv_heads
+    if num_queries_per_kv > 1:
+        # Handle MQA and GQA
+        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
+        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
     attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
     output = xops.memory_efficient_attention_forward(
         query.unsqueeze(0),
@@ -285,40 +279,7 @@ def run_multi_query_kv_attention(
         query,
         key,
         value,
+        scale,
         dtype,
     )
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
-
-
-def test_single_query_cached_kv_attention() -> None:
-    torch.random.manual_seed(TEST_SEED)
-    torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for block_size in [8, 16, 32]:
-            for head_size in [64, 80, 96, 112, 128, 256]:
-                print(f'Testing single_query_cached_kv_attention with '
-                      f'dtype={dtype}, block_size={block_size}, '
-                      f'head_size={head_size}')
-                run_single_query_cached_kv_attention(
-                    num_tokens=37,
-                    num_heads=3,
-                    head_size=head_size,
-                    block_size=block_size,
-                    num_blocks=1024,
-                    dtype=dtype,
-                )
-
-
-def test_multi_query_kv_attention() -> None:
-    torch.random.manual_seed(TEST_SEED)
-    torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for head_size in [64, 80, 96, 112, 128, 256]:
-            print(f'Testing multi_query_kv_attention with dtype={dtype}, '
-                  f'head_size={head_size}')
-            run_multi_query_kv_attention(
-                num_seqs=5,
-                num_heads=3,
-                head_size=head_size,
-                dtype=dtype,
-            )
```
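
A notable addition in test_attention.py is MQA/GQA coverage: `NUM_HEADS` now pairs query-head and KV-head counts such as `(64, 8)`, and the reference path widens the key/value tensors with `torch.repeat_interleave` so that each KV head serves `num_queries_per_kv` query heads. The following stand-alone snippet illustrates that expansion with made-up shapes; it mirrors the `repeat_interleave` calls in the diff above but is not part of the commit.

```python
import torch

num_query_heads, num_kv_heads = 64, 8
num_queries_per_kv = num_query_heads // num_kv_heads  # 8 query heads per KV head

# Expand [num_tokens, num_kv_heads, head_size] to [num_tokens, num_query_heads, head_size].
key = torch.randn(16, num_kv_heads, 128)
key_expanded = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
print(key_expanded.shape)  # torch.Size([16, 64, 128])

# The kernel side uses the same trick on head indices to build head_mapping:
# KV head i backs query heads 8*i .. 8*i + 7.
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)
print(head_mapping[:10])  # tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=torch.int32)
```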
tests/kernels/test_cache.py — view file @ fbd80ad4

```diff
 import random

+import pytest
 import torch

 from vllm import cache_ops

+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+NUM_LAYERS = [5]  # Arbitrary values for testing
+NUM_HEADS = [8]  # Arbitrary values for testing
+HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+BLOCK_SIZES = [8, 16, 32]
+NUM_BLOCKS = [1024]  # Arbitrary values for testing
+NUM_MAPPINGS = [32, 256]  # Arbitrary values for testing
+SEEDS = [0]
+
+
+@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
+@pytest.mark.parametrize("num_layers", NUM_LAYERS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_copy_blocks(
+def test_copy_blocks(
+    kv_cache_factory,
     num_mappings: int,
     num_layers: int,
     num_heads: int,
@@ -14,48 +34,43 @@ def run_copy_blocks(
     block_size: int,
     num_blocks: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
-    # Generate random block mappings.
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    # Generate random block mappings where each source block is mapped to two
+    # destination blocks.
+    assert 2 * num_mappings <= num_blocks
     src_blocks = random.sample(range(num_blocks), num_mappings)
     remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
-    dst_blocks = random.sample(remainig_blocks, num_mappings)
-    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
-
-    # Create the KV cache.
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_caches = []
-    for _ in range(num_layers):
-        key_cache = torch.randn(size=key_cache_shape,
-                                dtype=dtype,
-                                device='cuda')
-        key_caches.append(key_cache)
-    cloned_key_caches = []
-    for key_cache in key_caches:
-        cloned_key_caches.append(key_cache.clone())
-
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_caches = []
-    for _ in range(num_layers):
-        value_cache = torch.randn(size=value_cache_shape,
-                                  dtype=dtype,
-                                  device='cuda')
-        value_caches.append(value_cache)
-    cloned_value_caches = []
-    for value_cache in value_caches:
-        cloned_value_caches.append(value_cache.clone())
+    dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
+    block_mapping = {}
+    for i in range(num_mappings):
+        src = src_blocks[i]
+        dst1 = dst_blocks[2 * i]
+        dst2 = dst_blocks[2 * i + 1]
+        block_mapping[src] = [dst1, dst2]
+
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
+                                                num_layers, num_heads,
+                                                head_size, dtype, seed)
+
+    # Clone the KV caches.
+    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
+    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

     # Call the copy blocks kernel.
     cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

-    # Reference implementation.
+    # Run the reference implementation.
     for src, dsts in block_mapping.items():
         for dst in dsts:
-            for cloned_key_cache in cloned_key_caches:
+            for key_cache, cloned_key_cache in zip(key_caches,
+                                                   cloned_key_caches):
                 cloned_key_cache[dst] = cloned_key_cache[src]
-            for cloned_value_cache in cloned_value_caches:
+            for value_cache, cloned_value_cache in zip(value_caches,
+                                                       cloned_value_caches):
                 cloned_value_cache[dst] = cloned_value_cache[src]

     # Compare the results.
@@ -66,15 +81,29 @@ def run_copy_blocks(
     assert torch.allclose(value_cache, cloned_value_cache)


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_reshape_and_cache(
+def test_reshape_and_cache(
+    kv_cache_factory,
     num_tokens: int,
     num_heads: int,
     head_size: int,
     block_size: int,
     num_blocks: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
     slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
@@ -87,110 +116,31 @@ def run_reshape_and_cache(
                       device='cuda')
     _, key, value = qkv.unbind(dim=1)

-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_cache = torch.randn(size=key_cache_shape,
-                            dtype=dtype,
-                            device='cuda')
-    cloned_key_cache = key_cache.clone()
-
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(size=value_cache_shape,
-                              dtype=dtype,
-                              device='cuda')
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
+                                                num_heads, head_size, dtype,
+                                                seed)
+    key_cache, value_cache = key_caches[0], value_caches[0]
+
+    # Clone the KV caches.
+    cloned_key_cache = key_cache.clone()
     cloned_value_cache = value_cache.clone()

+    # Call the reshape_and_cache kernel.
     cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                 slot_mapping)

+    # Run the reference implementation.
+    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
+    block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
+    block_indicies = block_indicies.cpu().tolist()
+    block_offsets = slot_mapping % block_size
+    block_offsets = block_offsets.cpu().tolist()
     for i in range(num_tokens):
-        reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
-        block_idx = torch.div(slot_mapping[i],
-                              block_size,
-                              rounding_mode='floor')
-        block_offset = slot_mapping[i] % block_size
+        block_idx = block_indicies[i]
+        block_offset = block_offsets[i]
         cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
         cloned_value_cache[block_idx, :, :, block_offset] = value[i]

     assert torch.allclose(key_cache, cloned_key_cache)
     assert torch.allclose(value_cache, cloned_value_cache)
-
-
-@torch.inference_mode()
-def run_gather_cached_kv(
-    num_tokens: int,
-    num_heads: int,
-    head_size: int,
-    block_size: int,
-    num_blocks: int,
-    dtype: torch.dtype,
-) -> None:
-    num_slots = block_size * num_blocks
-    slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
-
-    qkv = torch.randn(num_tokens,
-                      3,
-                      num_heads,
-                      head_size,
-                      dtype=dtype,
-                      device='cuda')
-    _, key, value = qkv.unbind(dim=1)
-
-    qkv_clone = qkv.clone()
-    _, cloned_key, cloned_value = qkv_clone.unbind(dim=1)
-
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_cache = torch.randn(size=key_cache_shape,
-                            dtype=dtype,
-                            device='cuda')
-
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(size=value_cache_shape,
-                              dtype=dtype,
-                              device='cuda')
-
-    cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
-                               slot_mapping)
-
-    # Reference implementation.
-    for i in range(num_tokens):
-        reshaped_key = cloned_key.reshape(num_tokens, num_heads,
-                                          head_size // x, x)
-        block_idx = torch.div(slot_mapping[i],
-                              block_size,
-                              rounding_mode='floor')
-        block_offset = slot_mapping[i] % block_size
-        reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
-        cloned_value[i] = value_cache[block_idx, :, :, block_offset]
-
-    assert torch.allclose(key, cloned_key)
-    assert torch.allclose(value, cloned_value)
-
-
-def test_copy_blocks() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_copy_blocks(num_mappings=23,
-                        num_layers=7,
-                        num_heads=17,
-                        head_size=16,
-                        block_size=8,
-                        num_blocks=1024,
-                        dtype=dtype)
-
-
-def test_reshape_and_cache() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_reshape_and_cache(num_tokens=3,
-                              num_heads=2,
-                              head_size=16,
-                              block_size=8,
-                              num_blocks=2,
-                              dtype=dtype)
-
-
-def test_gather_cached_kv() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_gather_cached_kv(num_tokens=3,
-                             num_heads=2,
-                             head_size=16,
-                             block_size=8,
-                             num_blocks=2,
-                             dtype=dtype)
```
tests/kernels/test_layernorm.py — view file @ fbd80ad4

```diff
+import pytest
 import torch
 import torch.nn as nn

 from vllm import layernorm_ops

+DTYPES = [torch.half, torch.bfloat16, torch.float]
+HIDDEN_SIZES = [67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
+NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
+SEEDS = [0]
+

 class RefRMSNorm(nn.Module):

     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         weight = torch.empty(hidden_size)
-        weight.uniform_(-1e-3, 1e-3)
+        weight.normal_(mean=1.0, std=0.1)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps

     def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1,
-                                                               keepdim=True)
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance +
                                                     self.variance_epsilon)
-        if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-        return self.weight * hidden_states
+        return self.weight * hidden_states.to(input_dtype)


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_rms_norm(
+def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    scale = float(hidden_size**-0.5)
+    x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    x.uniform_(-scale, scale)
     ref = RefRMSNorm(hidden_size).to(dtype).cuda()
     out = torch.empty_like(x)
@@ -40,17 +55,4 @@ def run_rms_norm(
         ref.variance_epsilon,
     )
     ref_out = ref(x)
-    assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
-
-
-def test_rms_norm() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 128, 2048]:
-            for hidden_size in [13, 64, 1024, 5120]:
-                print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
-                      f'{num_tokens}, hidden_size={hidden_size}')
-                run_rms_norm(
-                    num_tokens=num_tokens,
-                    hidden_size=hidden_size,
-                    dtype=dtype,
-                )
+    assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)
```
tests/kernels/test_pos_encoding.py — view file @ fbd80ad4

```diff
-from typing import Tuple
+from typing import Optional, Tuple

+import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from vllm import pos_encoding_ops

+DTYPES = [torch.half, torch.bfloat16, torch.float]
+HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
+NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+SEEDS = [0]
+

 def rotate_half(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., :x.shape[-1] // 2]
@@ -74,16 +82,28 @@ class RefRotaryEmbeddingNeox(nn.Module):
         return query, key


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_rotary_embedding_neox(
+def test_rotary_embedding_neox(
     num_tokens: int,
     num_heads: int,
     head_size: int,
-    max_position: int,
-    rotary_dim: int,
+    rotary_dim: Optional[int],
     dtype: torch.dtype,
+    seed: int,
+    max_position: int = 8192,
     base: int = 10000,
 ) -> None:
+    if rotary_dim is None:
+        rotary_dim = head_size
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
     query = torch.randn(num_tokens,
                         num_heads * head_size,
@@ -97,7 +117,7 @@ def run_rotary_embedding_neox(
     # Create the rotary embedding.
     inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
     t = torch.arange(max_position).float()
-    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
@@ -129,19 +149,5 @@ def run_rotary_embedding_neox(
     ref_key = ref_key.view(num_tokens, num_heads * head_size)

     # Compare the results.
-    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
-    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
-
-
-def test_rotary_embedding_neox() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
-            print(f'Running tests for head_size={head_size} and dtype={dtype}')
-            run_rotary_embedding_neox(
-                num_tokens=2145,
-                num_heads=5,
-                head_size=head_size,
-                max_position=8192,
-                rotary_dim=head_size,
-                dtype=dtype,
-            )
+    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
```