norm / vllm · Commits · fbd80ad4

Unverified commit fbd80ad4, authored Sep 06, 2023 by Woosuk Kwon, committed by GitHub Sep 05, 2023. Parent: 22379d55.

Clean up kernel unit tests (#938)

Showing 6 changed files with 364 additions and 399 deletions (+364, -399).
tests/kernels/conftest.py            +43    -0
tests/kernels/test_activation.py     +31   -28
tests/kernels/test_attention.py     +159  -198
tests/kernels/test_cache.py          +80  -130
tests/kernels/test_layernorm.py      +24   -22
tests/kernels/test_pos_encoding.py   +27   -21
tests/kernels/conftest.py (new file, 0 → 100644)

from typing import List, Tuple

import pytest
import torch


def create_kv_caches(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
    seed: int,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = head_size**-0.5
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device='cuda')
        key_cache.uniform_(-scale, scale)
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device='cuda')
        value_cache.uniform_(-scale, scale)
        value_caches.append(value_cache)
    return key_caches, value_caches


@pytest.fixture()
def kv_cache_factory():
    return create_kv_caches
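As a usage sketch (not part of this commit), a test that declares kv_cache_factory as an argument receives the create_kv_caches function itself and calls it with whatever shapes it needs; the test name and parameter values below are illustrative only:

import torch

def test_example(kv_cache_factory):
    # kv_cache_factory is create_kv_caches, injected by pytest.
    key_caches, value_caches = kv_cache_factory(
        num_blocks=128, block_size=16, num_layers=1,
        num_heads=8, head_size=64, dtype=torch.half, seed=0)
    key_cache, value_cache = key_caches[0], value_caches[0]
    # key_cache has shape (num_blocks, num_heads, head_size // x, block_size, x).
    assert key_cache.shape[0] == 128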
tests/kernels/test_activation.py

pytest is now imported and the test parameters become module-level lists:

import pytest
import torch
import torch.nn.functional as F
from transformers.activations import get_activation

from vllm import activation_ops

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
SEEDS = [0]


def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(chunks=2, dim=1)
    return F.silu(x1) * x2

The run_silu_and_mul helper becomes a parametrized test_silu_and_mul that also takes a seed:

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_silu_and_mul(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.silu_and_mul(out, x)
    ...
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)

The old hand-rolled driver test_silu_and_mul(), which looped over dtypes, token counts, and dimensions, printed each combination, and called run_silu_and_mul(num_tokens, d, dtype), is removed.

run_gelu_new becomes test_gelu_new in the same way:

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_gelu_new(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.gelu_new(out, x)
    ...
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)

and its loop-based driver is likewise removed. run_gelu_fast becomes test_gelu_fast:

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
def test_gelu_fast(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.gelu_fast(out, x)
    ref_out = get_activation("gelu_fast")(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)

The old print-and-loop drivers test_gelu_new() and test_gelu_fast() are removed as well.
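For reference, stacking pytest.mark.parametrize decorators as above makes pytest generate one test case per combination of values, which is what lets the hand-written loops and print statements be deleted. A minimal self-contained sketch with made-up values:

import pytest

@pytest.mark.parametrize("num_tokens", [7, 83])
@pytest.mark.parametrize("d", [512, 4096])
def test_cross_product(num_tokens: int, d: int) -> None:
    # pytest runs this 2 * 2 = 4 times, once per (num_tokens, d) pair.
    assert num_tokens > 0 and d > 0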
tests/kernels/test_attention.py

Imports gain pytest and Tuple, and the ad-hoc constants (MAX_SEQ_LEN = 4096, TEST_SEED = 0) are replaced with parameter lists:

import random
from typing import List, Optional, Tuple

import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from vllm import attention_ops

MAX_SEQ_LEN = 8192
NUM_BLOCKS = 128  # Arbitrary values for testing
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [1, 3, 7]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
USE_ALIBI = [False]  # TODO(woosuk): Add USE_ALIBI=True
SEEDS = [0]

ref_masked_attention now computes the attention logits in float32 and casts the softmax back to the value dtype (previously it scaled the query in place and kept everything in the input dtype):

def ref_masked_attention(
    ...
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out

ref_single_query_cached_kv_attention gains MQA/GQA and ALiBi support and now takes the scale explicitly instead of recomputing it per sequence:

def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables = block_tables.cpu().tolist()
    context_lens = context_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables[i]
        context_len = int(context_lens[i])
        ...
            block_offset = j % block_size
            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys.append(k)
            v = value_cache[block_number, :, :, block_offset]
            values.append(v)
        keys = torch.stack(keys, dim=0)
        values = torch.stack(values, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            position_ids = torch.arange(context_len, device="cuda").int()
            alibi_bias = (context_len - position_ids).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)

The old ref_multi_query_kv_attention and ref_multi_query_cached_kv_attention helpers at this location are removed: the former reappears below with an explicit scale argument, and the latter (which built per-query causal masks offset by context_len - query_len + 1 and gathered keys and values block by block) is dropped entirely.

run_single_query_cached_kv_attention becomes a parametrized test that takes the kv_cache_factory fixture and a (query heads, KV heads) tuple:

@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_single_query_cached_kv_attention(
    kv_cache_factory,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype, device="cuda")
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device="cuda")

    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size, dtype, seed)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    attention_ops.single_query_cached_kv_attention(
        output,
        query,
        ...
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes,
    )

    # Run the reference implementation.
    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)

The old version built qkv and the KV caches by hand with uniform_(-1e-3, 1e-3), drew context_lens per token, used an identity head_mapping of torch.arange(num_heads), passed None for the ALiBi slopes, and allocated the output with torch.empty; all of that is replaced by the code above, and the NOTE about the softmax dtype difference is rewritten accordingly.

ref_multi_query_kv_attention is moved here and now takes the scale explicitly:

def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype, device="cuda")

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)
    return ref_output

run_multi_query_kv_attention becomes a parametrized test that handles separate query and KV head counts:

@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype,
                      device="cuda")
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
    ...
    ref_output = ref_multi_query_kv_attention(
        ...
        query,
        key,
        value,
        scale,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)

The old (num_tokens, 3, num_heads, head_size) qkv tensor filled with uniform_(-1e-3, 1e-3) and split via unbind(dim=1) is gone, and so are the old loop drivers test_single_query_cached_kv_attention() and test_multi_query_kv_attention(), which swept dtypes, block sizes, and head sizes with print statements and hard-coded values (num_tokens=37, num_seqs=5, num_heads=3, num_blocks=1024).
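To illustrate the MQA/GQA handling above, torch.repeat_interleave with an integer count repeats each KV head index so that every query head is assigned a KV head; the head counts below are chosen for illustration, not taken from the diff:

import torch

num_query_heads, num_kv_heads = 8, 2
num_queries_per_kv = num_query_heads // num_kv_heads  # 4 query heads per KV head
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)
print(head_mapping.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1]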
tests/kernels/test_cache.py

pytest is imported and the test parameters become module-level lists:

import random

import pytest
import torch

from vllm import cache_ops

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
NUM_LAYERS = [5]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024]  # Arbitrary values for testing
NUM_MAPPINGS = [32, 256]  # Arbitrary values for testing
SEEDS = [0]

run_copy_blocks becomes a parametrized test_copy_blocks that uses the kv_cache_factory fixture; each source block is now mapped to two destination blocks, and the caches are no longer built by hand:

@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    ...
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks.
    assert 2 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
    block_mapping = {}
    for i in range(num_mappings):
        src = src_blocks[i]
        dst1 = dst_blocks[2 * i]
        dst2 = dst_blocks[2 * i + 1]
        block_mapping[src] = [dst1, dst2]

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, dtype, seed)

    # Clone the KV caches.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Run the reference implementation.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
                cloned_key_cache[dst] = cloned_key_cache[src]
            for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
                cloned_value_cache[dst] = cloned_value_cache[src]

    # Compare the results.
    ...
    assert torch.allclose(value_cache, cloned_value_cache)

Previously the mapping was one destination per source, the key and value caches were built with explicit torch.randn loops, and the clones were accumulated in for loops rather than list comprehensions.

run_reshape_and_cache becomes test_reshape_and_cache, also built on the fixture, and the reference block/offset computation is hoisted out of the token loop:

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
    ...
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size, dtype, seed)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches.
    cloned_key_cache = key_cache.clone()
    cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)

    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
    block_indicies = block_indicies.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indicies[i]
        block_offset = block_offsets[i]
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    assert torch.allclose(key_cache, cloned_key_cache)
    assert torch.allclose(value_cache, cloned_value_cache)

Previously the key and value caches were created by hand with torch.randn, and the x-based reshape, the floor division, and the modulo were recomputed inside the token loop.

The remaining removed code is run_gather_cached_kv, which gathered keys and values back out of hand-built caches and compared them against clones, together with the old loop drivers test_copy_blocks(), test_reshape_and_cache(), and test_gather_cached_kv(), each of which iterated over dtypes with hard-coded sizes and called the corresponding run_* helper.
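The reference implementation above relies on the usual slot-to-block arithmetic: a flat slot index splits into a block index (floor division) and an offset within the block (modulo). A small worked example with an illustrative block size:

import torch

block_size = 16
slot_mapping = torch.tensor([0, 17, 40])
block_indices = torch.div(slot_mapping, block_size, rounding_mode='floor')
block_offsets = slot_mapping % block_size
print(block_indices.tolist())  # [0, 1, 2]
print(block_offsets.tolist())  # [0, 1, 8]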
tests/kernels/test_layernorm.py

import pytest
import torch
import torch.nn as nn

from vllm import layernorm_ops

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
SEEDS = [0]


class RefRMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        weight = torch.empty(hidden_size)
        weight.normal_(mean=1.0, std=0.1)
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

Previously the weight was initialized with uniform_(-1e-3, 1e-3), and forward computed the variance in float32 but only cast the hidden states to the weight dtype when the weight was half, float16, or bfloat16.

run_rms_norm becomes a parametrized test, and the input is now seeded and scaled rather than drawn from an unseeded torch.randn:

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(hidden_size**-0.5)
    x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
    x.uniform_(-scale, scale)
    ref = RefRMSNorm(hidden_size).to(dtype).cuda()

    out = torch.empty_like(x)
    ...
        ref.variance_epsilon,
    )
    ref_out = ref(x)
    assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)

The old atol=1e-3 assertion and the print-and-loop driver test_rms_norm(), which swept dtypes, token counts, and hidden sizes, are removed; the new assertion uses atol=1e-2, rtol=1e-5.
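For context, RefRMSNorm computes y = weight * x / sqrt(mean(x**2) + eps) per token; a hand-checked example with made-up values:

import torch

x = torch.tensor([[3.0, 4.0]])              # one token, hidden_size = 2
weight = torch.tensor([1.0, 1.0])
eps = 1e-6
variance = x.pow(2).mean(-1, keepdim=True)   # (9 + 16) / 2 = 12.5
out = weight * (x * torch.rsqrt(variance + eps))
print(out)  # ~[[0.8485, 1.1314]], i.e. x / sqrt(12.5)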
tests/kernels/test_pos_encoding.py

The typing import gains Optional, pytest is imported, and the test parameters become module-level lists:

from typing import Optional, Tuple

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm import pos_encoding_ops

DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
SEEDS = [0]


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    ...

The RefRotaryEmbeddingNeox reference module is not touched by this diff. run_rotary_embedding_neox becomes a parametrized test: rotary_dim is now Optional (None means rotary dim == head size), max_position gets a default of 8192, and the run is seeded:

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_rotary_embedding_neox(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    if rotary_dim is None:
        rotary_dim = head_size
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
    query = torch.randn(num_tokens,
                        num_heads * head_size,
    ...

    # Create the rotary embedding.
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
    t = torch.arange(max_position).float()
    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
    cos = freqs.cos()
    sin = freqs.sin()
    cos_sin_cache = torch.cat((cos, sin), dim=-1)
    ...

    ref_key = ref_key.view(num_tokens, num_heads * head_size)

    # Compare the results.
    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)

The old signature took max_position and rotary_dim as required ints, the old assertions used atol=1e-3, and the print-and-loop driver test_rotary_embedding_neox(), which swept head sizes with num_tokens=2145 and num_heads=5, is removed; the new assertions tighten the tolerance to atol=1e-5, rtol=1e-5.
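The diff truncates the body of rotate_half; for orientation, a generic GPT-NeoX-style sketch is shown below. It is a common reference implementation consistent with the one line visible in the diff, not necessarily vLLM's exact code:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]    # first half of the last dimension (as in the diff)
    x2 = x[..., x.shape[-1] // 2:]    # second half (assumed)
    return torch.cat((-x2, x1), dim=-1)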