xdb4_94051 / vllm — Commit 88c0268a (unverified)

Implement custom kernel for LLaMA rotary embedding (#14)

Authored Mar 30, 2023 by Woosuk Kwon; committed via GitHub on Mar 30, 2023.
Parent: 80a2f812

Showing 10 changed files, with 318 additions and 69 deletions (+318 / -69):

    cacheflow/models/attention.py        +71   -2
    cacheflow/models/llama.py             +5  -59
    cacheflow/models/memory_analyzer.py   +1   -2
    cacheflow/models/opt.py               +1   -2
    cacheflow/models/sample.py            +1   -1
    csrc/cache_kernels.cu                 +3   -3
    csrc/pos_encoding.cpp                +16   -0
    csrc/pos_encoding_kernels.cu         +83   -0
    setup.py                              +8   -0
    tests/kernels/pos_encoding.py       +129   -0
cacheflow/models/attention.py

@@ -6,13 +6,14 @@ import torch.nn as nn
 from cacheflow import attention_ops
 from cacheflow import cache_ops
+from cacheflow import pos_encoding_ops
 from cacheflow.models import InputMetadata
 
 
-class OPTCacheFlowAttention(nn.Module):
+class GPTCacheFlowAttention(nn.Module):
 
     def __init__(self, scale: float) -> None:
-        super(OPTCacheFlowAttention, self).__init__()
+        super().__init__()
         self.scale = float(scale)
 
         self.flash_attn = FlashAttention(softmax_scale=self.scale)
@@ -136,3 +137,71 @@ class OPTCacheFlowAttention(nn.Module):
         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
         return output.view(-1, num_heads * head_size)
+
+
+class OPTCacheFlowAttention(GPTCacheFlowAttention):
+    """OPT uses the same attention mechanism as GPT."""
+
+    def __init__(self, scale: float) -> None:
+        super().__init__(scale)
+
+
+class LlamaCacheFlowAttention(GPTCacheFlowAttention):
+    """Llama uses GPT-NeoX style rotary embedding."""
+
+    def __init__(
+        self,
+        scale: float,
+        head_size: int,
+        max_position: int = 8192,
+        base: int = 10000,
+    ) -> None:
+        super().__init__(scale)
+
+        # Create the cos and sin cache.
+        inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
+        t = torch.arange(max_position).float()
+        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+
+        # FIXME(woosuk): This assumes that we configure the default dtype when
+        # initializing the model. Make it more robust.
+        torch_dtype = torch.get_default_dtype()
+        cache = cache.to(torch_dtype)
+        # Embedding size: [max_position, head_size]
+        self.register_buffer('cos_sin_cache', cache, persistent=False)
+
+    def forward(
+        self,
+        positions: torch.LongTensor,              # [num_tokens]
+        query: torch.Tensor,                      # [num_tokens, num_heads * head_size]
+        key: torch.Tensor,                        # [num_tokens, num_heads * head_size]
+        value: torch.Tensor,                      # [num_tokens, num_heads * head_size]
+        key_cache: torch.Tensor,                  # [num_blocks, num_heads, head_size/x, block_size, x]
+        value_cache: torch.Tensor,                # [num_blocks, num_heads, head_size, block_size]
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:                            # [num_tokens, num_heads * head_size]
+        # Apply rotary embedding to the query and key before passing them
+        # to the attention op.
+        out_query = torch.empty_like(query)
+        out_key = torch.empty_like(key)
+        pos_encoding_ops.rotary_embedding_neox(
+            out_query,
+            out_key,
+            positions,
+            query,
+            key,
+            self.cos_sin_cache,
+        )
+        return super().forward(
+            out_query,
+            out_key,
+            value,
+            key_cache,
+            value_cache,
+            input_metadata,
+            cache_event,
+        )
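For readers tracing what the new `rotary_embedding_neox` op computes against the `[max_position, head_size]` cos/sin cache built above, the following is a rough pure-PyTorch equivalent. It is a sketch for exposition only (not part of the commit; the function name and layout comments are illustrative); the authoritative numerical check is the reference implementation in tests/kernels/pos_encoding.py below.

import torch

def rotary_embedding_neox_ref(positions, query, key, cos_sin_cache, num_heads, head_size):
    # cos_sin_cache[p] = cat([cos(freqs[p]), sin(freqs[p])]), length head_size,
    # so the first head_size // 2 entries are cos and the rest are sin.
    cos_half, sin_half = cos_sin_cache[positions].chunk(2, dim=-1)  # [num_tokens, head_size // 2] each
    cos = cos_half.repeat(1, 2).view(-1, 1, head_size)              # [num_tokens, 1, head_size]
    sin = sin_half.repeat(1, 2).view(-1, 1, head_size)

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q = query.view(-1, num_heads, head_size)
    k = key.view(-1, num_heads, head_size)
    out_q = q * cos + rotate_half(q) * sin
    out_k = k * cos + rotate_half(k) * sin
    return out_q.flatten(1), out_k.flatten(1)   # [num_tokens, num_heads * head_size]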
cacheflow/models/llama.py

@@ -8,12 +8,10 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 import torch
 from torch import nn
-import torch.nn.functional as F
 from transformers import LlamaConfig
 from transformers import PreTrainedModel
 
 from cacheflow.models import InputMetadata
-from cacheflow.models.attention import OPTCacheFlowAttention
+from cacheflow.models.attention import LlamaCacheFlowAttention
 from cacheflow.models.sample import Sampler
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -41,48 +39,8 @@ class LlamaRMSNorm(nn.Module):
         return self.weight * hidden_states
 
 
-class LlamaRotaryEmbedding(torch.nn.Module):
-
-    def __init__(self, dim, max_position_embeddings=2048, base=10000):
-        super().__init__()
-        self.max_position_embeddings = max_position_embeddings
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-        # Create cos and sin embeddings.
-        t = torch.arange(max_position_embeddings).float()
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
-        emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos().to(dtype=self.inv_freq.dtype)
-        sin = emb.sin().to(dtype=self.inv_freq.dtype)
-        self.register_buffer("cos_cached", cos, persistent=False)
-        self.register_buffer("sin_cached", sin, persistent=False)
-
-    def forward(
-        self,
-        positions: torch.LongTensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        cos = F.embedding(positions, self.cos_cached)
-        sin = F.embedding(positions, self.sin_cached)
-        return cos, sin
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin):
-    # TODO: Optimize.
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
 class LlamaMLP(nn.Module):
 
     def __init__(
         self,
         hidden_size: int,
@@ -156,9 +114,7 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
-        # FIXME(woosuk): Rename this.
-        self.attn = OPTCacheFlowAttention(scale=self.scaling)
+        self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
 
     def forward(
         self,
@@ -171,19 +127,9 @@ class LlamaAttention(nn.Module):
         q, _ = self.q_proj(hidden_states)
         k, _ = self.k_proj(hidden_states)
         v, _ = self.v_proj(hidden_states)
-
-        # Apply rotrary embedding.
-        # TODO: Optimize.
-        q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
-        k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
-        cos, sin = self.rotary_emb(positions)
-        q, k = apply_rotary_pos_emb(q, k, cos, sin)
-        q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
-        k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
-
-        key_cache, value_cache = kv_cache
-        attn_output = self.attn(
-            q, k, v, key_cache, value_cache, input_metadata, cache_event)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(
+            positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
         output, _ = self.o_proj(attn_output)
         return output
 
cacheflow/models/memory_analyzer.py

@@ -165,8 +165,7 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         self.head_size = config.hidden_size // self.num_heads
         self.ffn_size = config.intermediate_size
         self.vocab_size = config.vocab_size
-        # FIXME
-        self.max_position = 2048
+        self.max_position = 8192
 
     def _get_param_size(self) -> int:
         word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
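The bump to 8192 matches the new `max_position=8192` default of `LlamaCacheFlowAttention` above. As a back-of-the-envelope aside on what an 8192-position rotary cache costs per attention layer (a sketch only; `head_size=128` and fp16 are illustrative assumptions, not fixed by this diff):

max_position = 8192
head_size = 128        # illustrative; depends on the model config
bytes_per_elem = 2     # assumes an fp16 default dtype
cache_bytes = max_position * head_size * bytes_per_elem
print(cache_bytes / (1024 * 1024))   # ~2.0 MiB for the non-persistent cos_sin_cache buffer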
cacheflow/models/opt.py

@@ -51,7 +51,7 @@ class OPTAttention(nn.Module):
         assert num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = embed_dim // total_num_heads
-        self.scaling = self.head_dim**-0.5
+        self.scaling = self.head_dim ** -0.5
 
         # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
         self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
@@ -66,7 +66,6 @@ class OPTAttention(nn.Module):
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
         self.attn = OPTCacheFlowAttention(scale=self.scaling)
 
     def forward(
cacheflow/models/sample.py

@@ -12,7 +12,7 @@ from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_pa
 class Sampler(nn.Module):
 
     def __init__(self) -> None:
-        super(Sampler, self).__init__()
+        super().__init__()
 
     def forward(
         self,
csrc/cache_kernels.cu

@@ -122,13 +122,13 @@ void reshape_and_cache(
   torch::Tensor& value_cache,
   torch::Tensor& slot_mapping) {
   int num_tokens = key.size(0);
-  int head_num = key.size(1);
+  int num_heads = key.size(1);
   int head_size = key.size(2);
   int block_size = key_cache.size(3);
   int x = key_cache.size(4);
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(head_num * head_size, 512));
+  dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     key.scalar_type(),
@@ -140,7 +140,7 @@ void reshape_and_cache(
       key_cache.data_ptr<scalar_t>(),
       value_cache.data_ptr<scalar_t>(),
       slot_mapping.data_ptr<int>(),
-      head_num,
+      num_heads,
       head_size,
       block_size,
       x);
csrc/pos_encoding.cpp  (new file, mode 100644)

#include <torch/extension.h>

void rotary_embedding_neox(
  torch::Tensor& out_query,
  torch::Tensor& out_key,
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  torch::Tensor& cos_sin_cache);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "rotary_embedding_neox",
    &rotary_embedding_neox,
    "Apply GPT-NeoX style rotary embedding to query and key");
}
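From Python, the binding is consumed the same way as in LlamaCacheFlowAttention.forward and the new test below: the op writes into output tensors preallocated by the caller rather than returning new ones. A minimal usage sketch (shapes follow the kernel's comments; it assumes the extension is built and a CUDA device is available):

import torch
from cacheflow import pos_encoding_ops

num_tokens, num_heads, head_size, max_position = 4, 8, 64, 2048
positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
query = torch.randn(num_tokens, num_heads * head_size, device='cuda')
key = torch.randn(num_tokens, num_heads * head_size, device='cuda')

# cos_sin_cache: [max_position, head_size], first half cos, second half sin.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_size, 2) / head_size))
freqs = torch.outer(torch.arange(max_position).float(), inv_freq.float())
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).cuda()

out_query = torch.empty_like(query)   # outputs are preallocated by the caller
out_key = torch.empty_like(key)
pos_encoding_ops.rotary_embedding_neox(
    out_query, out_key, positions, query, key, cos_sin_cache)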
csrc/pos_encoding_kernels.cu  (new file, mode 100644)

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

namespace cacheflow {

template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
  scalar_t* __restrict__ out_query,             // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ out_key,               // [num_tokens, num_heads, head_size]
  const int64_t* __restrict__ positions,        // [num_tokens]
  const scalar_t* __restrict__ query,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ key,             // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, head_size // 2]
  const int num_heads,
  const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;

  const int embed_dim = head_size / 2;
  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int idx = token_idx * n + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int token_head = token_idx * n + head_idx * head_size;

    const bool is_first_half = head_offset < embed_dim;
    const int rot_offset = head_offset % embed_dim;
    const int x_index = rot_offset;
    const int y_index = embed_dim + rot_offset;

    const scalar_t cos = __ldg(cache_ptr + x_index);
    const scalar_t sin = __ldg(cache_ptr + y_index);

    const scalar_t q_x = __ldg(query + token_head + x_index);
    const scalar_t q_y = __ldg(query + token_head + y_index);
    const scalar_t q_cos = is_first_half ? q_x : q_y;
    const scalar_t q_sin = is_first_half ? -q_y : q_x;
    out_query[idx] = q_cos * cos + q_sin * sin;

    const scalar_t k_x = __ldg(key + token_head + x_index);
    const scalar_t k_y = __ldg(key + token_head + y_index);
    const scalar_t k_cos = is_first_half ? k_x : k_y;
    const scalar_t k_sin = is_first_half ? -k_y : k_x;
    out_key[idx] = k_cos * cos + k_sin * sin;
  }
}

} // namespace cacheflow

void rotary_embedding_neox(
  torch::Tensor& out_query,       // [num_tokens, num_heads * head_size]
  torch::Tensor& out_key,         // [num_tokens, num_heads * head_size]
  torch::Tensor& positions,       // [num_tokens]
  torch::Tensor& query,           // [num_tokens, num_heads * head_size]
  torch::Tensor& key,             // [num_tokens, num_heads * head_size]
  torch::Tensor& cos_sin_cache)   // [max_position, head_size]
{
  int num_tokens = query.size(0);
  int head_size = cos_sin_cache.size(1);
  int num_heads = query.size(1) / head_size;
  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    query.scalar_type(),
    "rotary_embedding_neox",
    [&] {
      cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out_query.data_ptr<scalar_t>(),
        out_key.data_ptr<scalar_t>(),
        positions.data_ptr<int64_t>(),
        query.data_ptr<scalar_t>(),
        key.data_ptr<scalar_t>(),
        cos_sin_cache.data_ptr<scalar_t>(),
        num_heads,
        head_size);
    });
}
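To make the index arithmetic easier to audit, here is a scalar Python rendering of the kernel's loop body. It is a sketch for exposition only (not part of the commit) and mirrors the CUDA code above line by line: one block per token, with threads striding over the num_heads * head_size output elements.

def rotary_embedding_neox_scalar(out_query, out_key, positions, query, key,
                                 cos_sin_cache, num_heads, head_size):
    # All arguments are flat Python lists; layouts follow the kernel comments.
    embed_dim = head_size // 2
    n = num_heads * head_size
    for token_idx in range(len(positions)):          # one CUDA block per token
        pos = positions[token_idx]
        cache = cos_sin_cache[pos * head_size:(pos + 1) * head_size]
        for i in range(n):                           # threads stride over this loop
            idx = token_idx * n + i
            head_idx, head_offset = divmod(i, head_size)
            token_head = token_idx * n + head_idx * head_size
            is_first_half = head_offset < embed_dim
            rot_offset = head_offset % embed_dim
            x_index, y_index = rot_offset, embed_dim + rot_offset
            cos, sin = cache[x_index], cache[y_index]
            q_x, q_y = query[token_head + x_index], query[token_head + y_index]
            q_cos = q_x if is_first_half else q_y
            q_sin = -q_y if is_first_half else q_x
            out_query[idx] = q_cos * cos + q_sin * sin
            k_x, k_y = key[token_head + x_index], key[token_head + y_index]
            k_cos = k_x if is_first_half else k_y
            k_sin = -k_y if is_first_half else k_x
            out_key[idx] = k_cos * cos + k_sin * sin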
setup.py

@@ -23,6 +23,14 @@ attention_extension = cpp_extension.CUDAExtension(
 )
 ext_modules.append(attention_extension)
 
+# Positional encodings.
+positional_encoding_extension = cpp_extension.CUDAExtension(
+    name='cacheflow.pos_encoding_ops',
+    sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'],
+    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
+)
+ext_modules.append(positional_encoding_extension)
+
 setuptools.setup(
     name='cacheflow',
     ext_modules=ext_modules,
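With the new extension registered, rebuilding the package (for example with `pip install -e .` or `python setup.py build_ext --inplace`) makes the op importable. A quick smoke check might look like this (a sketch; it only verifies that the pybind11 module loads and exposes the function):

from cacheflow import pos_encoding_ops

# The pybind11-generated docstring should mention the rotary embedding op.
print(pos_encoding_ops.rotary_embedding_neox.__doc__)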
tests/kernels/pos_encoding.py  (new file, mode 100644)

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from cacheflow import pos_encoding_ops


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class RefRotaryEmbeddingNeox(nn.Module):
    """Reference implementation of the GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 2048,
        base: int = 10000,
    ) -> None:
        super().__init__()
        self.max_position_embeddings = max_position_embeddings

        # Create cos and sin embeddings.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype=inv_freq.dtype)
        sin = emb.sin().to(dtype=inv_freq.dtype)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def forward(
        self,
        positions: torch.LongTensor,  # [num_tokens]
        query: torch.Tensor,          # [num_tokens, num_heads, head_size]
        key: torch.Tensor,            # [num_tokens, num_heads, head_size]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cos = F.embedding(positions, self.cos_cached)
        sin = F.embedding(positions, self.sin_cached)
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        query, key = apply_rotary_pos_emb(query, key, cos, sin)
        query = query.transpose(0, 1).contiguous()
        key = key.transpose(0, 1).contiguous()
        # Output query/key shape: [num_tokens, num_heads, head_size]
        return query, key


@torch.inference_mode()
def test_rotary_embedding_neox(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    max_position: int,
    dtype: torch.dtype,
    base: int = 10000,
) -> None:
    positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
    query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
    key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')

    # Create the rotary embedding.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
    t = torch.arange(max_position).float()
    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
    cos = freqs.cos()
    sin = freqs.sin()
    cos_sin_cache = torch.cat((cos, sin), dim=-1)
    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')

    # Run the kernel.
    out_query = torch.empty_like(query)
    out_key = torch.empty_like(key)
    pos_encoding_ops.rotary_embedding_neox(
        out_query,
        out_key,
        positions,
        query,
        key,
        cos_sin_cache,
    )

    # Run the reference implementation.
    ref_rotary_embedding = RefRotaryEmbeddingNeox(
        dim=head_size,
        max_position_embeddings=max_position,
        base=base,
    ).to(dtype=dtype, device='cuda')
    ref_query, ref_key = ref_rotary_embedding(
        positions,
        query.view(num_tokens, num_heads, head_size),
        key.view(num_tokens, num_heads, head_size),
    )
    ref_query = ref_query.view(num_tokens, num_heads * head_size)
    ref_key = ref_key.view(num_tokens, num_heads * head_size)

    # Compare the results.
    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)


if __name__ == '__main__':
    for dtype in [torch.half, torch.float]:
        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
            print(f'Running tests for head_size={head_size} and dtype={dtype}')
            test_rotary_embedding_neox(
                num_tokens=2145,
                num_heads=5,
                head_size=head_size,
                max_position=8192,
                dtype=dtype,
            )
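The file drives its cases from the `__main__` loop. If one preferred to run the same grid under pytest, a parametrized wrapper along these lines would work (a sketch, not part of the commit; the existing function would also need a non-test name such as a leading underscore so pytest does not try to collect it directly, and a CUDA device is assumed):

import pytest
import torch

@pytest.mark.parametrize('dtype', [torch.half, torch.float])
@pytest.mark.parametrize('head_size', [32, 64, 80, 96, 128, 160, 192, 256])
def test_rotary_embedding_neox_parametrized(dtype, head_size):
    test_rotary_embedding_neox(
        num_tokens=2145,
        num_heads=5,
        head_size=head_size,
        max_position=8192,
        dtype=dtype,
    )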