Commit 88c0268a (unverified)
Implement custom kernel for LLaMA rotary embedding (#14)

Authored Mar 30, 2023 by Woosuk Kwon; committed via GitHub on Mar 30, 2023.
Parent: 80a2f812

Showing 10 changed files with 318 additions and 69 deletions (+318 -69):
cacheflow/models/attention.py        +71   -2
cacheflow/models/llama.py             +5  -59
cacheflow/models/memory_analyzer.py   +1   -2
cacheflow/models/opt.py               +1   -2
cacheflow/models/sample.py            +1   -1
csrc/cache_kernels.cu                 +3   -3
csrc/pos_encoding.cpp                +16   -0
csrc/pos_encoding_kernels.cu         +83   -0
setup.py                              +8   -0
tests/kernels/pos_encoding.py       +129   -0
cacheflow/models/attention.py  (view file @ 88c0268a)

@@ -6,13 +6,14 @@ import torch.nn as nn
 from cacheflow import attention_ops
 from cacheflow import cache_ops
+from cacheflow import pos_encoding_ops
 from cacheflow.models import InputMetadata


-class OPTCacheFlowAttention(nn.Module):
+class GPTCacheFlowAttention(nn.Module):

     def __init__(self, scale: float) -> None:
-        super(OPTCacheFlowAttention, self).__init__()
+        super().__init__()
         self.scale = float(scale)

         self.flash_attn = FlashAttention(softmax_scale=self.scale)

@@ -136,3 +137,71 @@ class OPTCacheFlowAttention(nn.Module):
         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
         return output.view(-1, num_heads * head_size)
+
+
+class OPTCacheFlowAttention(GPTCacheFlowAttention):
+    """OPT uses the same attention mechanism as GPT."""
+
+    def __init__(self, scale: float) -> None:
+        super().__init__(scale)
+
+
+class LlamaCacheFlowAttention(GPTCacheFlowAttention):
+    """Llama uses GPT-NeoX style rotary embedding."""
+
+    def __init__(
+        self,
+        scale: float,
+        head_size: int,
+        max_position: int = 8192,
+        base: int = 10000,
+    ) -> None:
+        super().__init__(scale)
+
+        # Create the cos and sin cache.
+        inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
+        t = torch.arange(max_position).float()
+        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+
+        # FIXME(woosuk): This assumes that we configure the default dtype when
+        # initializing the model. Make it more robust.
+        torch_dtype = torch.get_default_dtype()
+        cache = cache.to(torch_dtype)
+        # Embedding size: [max_position, head_size]
+        self.register_buffer('cos_sin_cache', cache, persistent=False)
+
+    def forward(
+        self,
+        positions: torch.LongTensor,              # [num_tokens]
+        query: torch.Tensor,                      # [num_tokens, num_heads * head_size]
+        key: torch.Tensor,                        # [num_tokens, num_heads * head_size]
+        value: torch.Tensor,                      # [num_tokens, num_heads * head_size]
+        key_cache: torch.Tensor,                  # [num_blocks, num_heads, head_size/x, block_size, x]
+        value_cache: torch.Tensor,                # [num_blocks, num_heads, head_size, block_size]
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:                            # [num_tokens, num_heads * head_size]
+        # Apply rotary embedding to the query and key before passing them
+        # to the attention op.
+        out_query = torch.empty_like(query)
+        out_key = torch.empty_like(key)
+        pos_encoding_ops.rotary_embedding_neox(
+            out_query,
+            out_key,
+            positions,
+            query,
+            key,
+            self.cos_sin_cache,
+        )
+        return super().forward(
+            out_query,
+            out_key,
+            value,
+            key_cache,
+            value_cache,
+            input_metadata,
+            cache_event,
+        )
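
For orientation, the new LlamaCacheFlowAttention.forward delegates the rotation of the flat query/key tensors to the custom op. A minimal pure-PyTorch sketch of the semantics the op is expected to have, assuming the cos_sin_cache layout built in __init__ above (cos and sin halves concatenated along the last dimension); the function name is hypothetical and this is an illustration, not the shipped kernel:

    import torch

    def neox_rotary_reference(
        positions: torch.LongTensor,   # [num_tokens]
        x: torch.Tensor,               # [num_tokens, num_heads * head_size]
        cos_sin_cache: torch.Tensor,   # [max_position, head_size] = cat(cos, sin)
        num_heads: int,
        head_size: int,
    ) -> torch.Tensor:
        # Look up cos/sin for each token's position and broadcast over heads.
        cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)   # each [num_tokens, head_size // 2]
        cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)       # [num_tokens, 1, head_size]
        sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
        x = x.view(-1, num_heads, head_size)
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)                 # GPT-NeoX style rotate_half
        return (x * cos + rotated * sin).view(-1, num_heads * head_size)

The same function applied to query and key should match the outputs written by pos_encoding_ops.rotary_embedding_neox, which is also what the reference implementation in tests/kernels/pos_encoding.py checks.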
cacheflow/models/llama.py  (view file @ 88c0268a)

@@ -8,12 +8,10 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 import torch
 from torch import nn
-import torch.nn.functional as F
 from transformers import LlamaConfig
-from transformers import PreTrainedModel

 from cacheflow.models import InputMetadata
-from cacheflow.models.attention import OPTCacheFlowAttention
+from cacheflow.models.attention import LlamaCacheFlowAttention
 from cacheflow.models.sample import Sampler
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)

@@ -41,48 +39,8 @@ class LlamaRMSNorm(nn.Module):
         return self.weight * hidden_states


-class LlamaRotaryEmbedding(torch.nn.Module):
-
-    def __init__(self, dim, max_position_embeddings=2048, base=10000):
-        super().__init__()
-        self.max_position_embeddings = max_position_embeddings
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-        # Create cos and sin embeddings.
-        t = torch.arange(max_position_embeddings).float()
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
-        emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos().to(dtype=self.inv_freq.dtype)
-        sin = emb.sin().to(dtype=self.inv_freq.dtype)
-        self.register_buffer("cos_cached", cos, persistent=False)
-        self.register_buffer("sin_cached", sin, persistent=False)
-
-    def forward(
-        self,
-        positions: torch.LongTensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        cos = F.embedding(positions, self.cos_cached)
-        sin = F.embedding(positions, self.sin_cached)
-        return cos, sin
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin):
-    # TODO: Optimize.
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
 class LlamaMLP(nn.Module):

     def __init__(
         self,
         hidden_size: int,

@@ -156,9 +114,7 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
-        # FIXME(woosuk): Rename this.
-        self.attn = OPTCacheFlowAttention(scale=self.scaling)
+        self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)

     def forward(
         self,

@@ -171,19 +127,9 @@ class LlamaAttention(nn.Module):
         q, _ = self.q_proj(hidden_states)
         k, _ = self.k_proj(hidden_states)
         v, _ = self.v_proj(hidden_states)
-
-        # Apply rotrary embedding.
-        # TODO: Optimize.
-        q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
-        k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
-        cos, sin = self.rotary_emb(positions)
-        q, k = apply_rotary_pos_emb(q, k, cos, sin)
-        q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
-        k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
-
-        key_cache, value_cache = kv_cache
+        k_cache, v_cache = kv_cache
         attn_output = self.attn(
-            q, k, v, key_cache, value_cache, input_metadata, cache_event)
+            positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
         output, _ = self.o_proj(attn_output)
         return output
cacheflow/models/memory_analyzer.py  (view file @ 88c0268a)

@@ -165,8 +165,7 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         self.head_size = config.hidden_size // self.num_heads
         self.ffn_size = config.intermediate_size
         self.vocab_size = config.vocab_size
-        # FIXME
-        self.max_position = 8192
+        self.max_position = 2048

     def _get_param_size(self) -> int:
         word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
cacheflow/models/opt.py  (view file @ 88c0268a)

@@ -51,7 +51,7 @@ class OPTAttention(nn.Module):
         assert num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = embed_dim // total_num_heads
-        self.scaling = self.head_dim**-0.5
+        self.scaling = self.head_dim ** -0.5

         # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
         self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,

@@ -66,7 +66,6 @@ class OPTAttention(nn.Module):
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
-
         self.attn = OPTCacheFlowAttention(scale=self.scaling)

     def forward(
cacheflow/models/sample.py  (view file @ 88c0268a)

@@ -12,7 +12,7 @@ from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_pa
 class Sampler(nn.Module):

     def __init__(self) -> None:
-        super(Sampler, self).__init__()
+        super().__init__()

     def forward(
         self,
csrc/cache_kernels.cu  (view file @ 88c0268a)

@@ -122,13 +122,13 @@ void reshape_and_cache(
   torch::Tensor& value_cache,
   torch::Tensor& slot_mapping) {
   int num_tokens = key.size(0);
-  int head_num = key.size(1);
+  int num_heads = key.size(1);
   int head_size = key.size(2);
   int block_size = key_cache.size(3);
   int x = key_cache.size(4);

   dim3 grid(num_tokens);
-  dim3 block(std::min(head_num * head_size, 512));
+  dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     key.scalar_type(),

@@ -140,7 +140,7 @@ void reshape_and_cache(
       key_cache.data_ptr<scalar_t>(),
       value_cache.data_ptr<scalar_t>(),
       slot_mapping.data_ptr<int>(),
-      head_num,
+      num_heads,
       head_size,
       block_size,
       x);
csrc/pos_encoding.cpp  (new file, 0 → 100644; view file @ 88c0268a)

#include <torch/extension.h>

void rotary_embedding_neox(
  torch::Tensor& out_query,
  torch::Tensor& out_key,
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  torch::Tensor& cos_sin_cache);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "rotary_embedding_neox",
    &rotary_embedding_neox,
    "Apply GPT-NeoX style rotary embedding to query and key");
}
csrc/pos_encoding_kernels.cu  (new file, 0 → 100644; view file @ 88c0268a)

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

namespace cacheflow {

template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
  scalar_t* __restrict__ out_query,           // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ out_key,             // [num_tokens, num_heads, head_size]
  const int64_t* __restrict__ positions,      // [num_tokens]
  const scalar_t* __restrict__ query,         // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
  const int num_heads,
  const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;

  const int embed_dim = head_size / 2;
  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int idx = token_idx * n + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int token_head = token_idx * n + head_idx * head_size;

    const bool is_first_half = head_offset < embed_dim;
    const int rot_offset = head_offset % embed_dim;
    const int x_index = rot_offset;
    const int y_index = embed_dim + rot_offset;

    const scalar_t cos = __ldg(cache_ptr + x_index);
    const scalar_t sin = __ldg(cache_ptr + y_index);

    const scalar_t q_x = __ldg(query + token_head + x_index);
    const scalar_t q_y = __ldg(query + token_head + y_index);
    const scalar_t q_cos = is_first_half ? q_x : q_y;
    const scalar_t q_sin = is_first_half ? -q_y : q_x;
    out_query[idx] = q_cos * cos + q_sin * sin;

    const scalar_t k_x = __ldg(key + token_head + x_index);
    const scalar_t k_y = __ldg(key + token_head + y_index);
    const scalar_t k_cos = is_first_half ? k_x : k_y;
    const scalar_t k_sin = is_first_half ? -k_y : k_x;
    out_key[idx] = k_cos * cos + k_sin * sin;
  }
}

} // namespace cacheflow

void rotary_embedding_neox(
  torch::Tensor& out_query,     // [num_tokens, num_heads * head_size]
  torch::Tensor& out_key,       // [num_tokens, num_heads * head_size]
  torch::Tensor& positions,     // [num_tokens]
  torch::Tensor& query,         // [num_tokens, num_heads * head_size]
  torch::Tensor& key,           // [num_tokens, num_heads * head_size]
  torch::Tensor& cos_sin_cache) // [max_position, head_size]
{
  int num_tokens = query.size(0);
  int head_size = cos_sin_cache.size(1);
  int num_heads = query.size(1) / head_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    query.scalar_type(),
    "rotary_embedding_neox",
    [&] {
      cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out_query.data_ptr<scalar_t>(),
        out_key.data_ptr<scalar_t>(),
        positions.data_ptr<int64_t>(),
        query.data_ptr<scalar_t>(),
        key.data_ptr<scalar_t>(),
        cos_sin_cache.data_ptr<scalar_t>(),
        num_heads,
        head_size);
    });
}
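
To make the kernel's indexing easier to follow: each thread handles one output element i of one token, pairs offset rot_offset with offset embed_dim + rot_offset inside the head, and picks operands and signs through is_first_half. A small CPU-only PyTorch sketch of the same per-element arithmetic, under the cos_sin_cache layout above (the function name is hypothetical; it mirrors the kernel's math for illustration, not the shipped implementation):

    import torch

    def rotary_embedding_neox_elementwise(
        positions: torch.LongTensor,   # [num_tokens]
        query: torch.Tensor,           # [num_tokens, num_heads * head_size]
        cos_sin_cache: torch.Tensor,   # [max_position, head_size]
        num_heads: int,
        head_size: int,
    ) -> torch.Tensor:
        embed_dim = head_size // 2
        out = torch.empty_like(query)
        for token_idx in range(positions.shape[0]):
            cache = cos_sin_cache[positions[token_idx]]        # [head_size] = cos | sin
            for i in range(num_heads * head_size):
                head_idx, head_offset = divmod(i, head_size)
                rot_offset = head_offset % embed_dim
                x_index, y_index = rot_offset, embed_dim + rot_offset
                cos, sin = cache[x_index], cache[y_index]
                base = head_idx * head_size
                q_x = query[token_idx, base + x_index]
                q_y = query[token_idx, base + y_index]
                if head_offset < embed_dim:
                    # First half of the head: x * cos - y * sin.
                    out[token_idx, i] = q_x * cos - q_y * sin
                else:
                    # Second half of the head: y * cos + x * sin.
                    out[token_idx, i] = q_y * cos + q_x * sin
        return out

The two branches are exactly the element-wise form of q * cos + rotate_half(q) * sin, which is why the kernel agrees with the GPT-NeoX reference used in the tests below.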
setup.py  (view file @ 88c0268a)

@@ -23,6 +23,14 @@ attention_extension = cpp_extension.CUDAExtension(
 )
 ext_modules.append(attention_extension)

+# Positional encodings.
+positional_encoding_extension = cpp_extension.CUDAExtension(
+    name='cacheflow.pos_encoding_ops',
+    sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'],
+    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
+)
+ext_modules.append(positional_encoding_extension)
+
 setuptools.setup(
     name='cacheflow',
     ext_modules=ext_modules,
tests/kernels/pos_encoding.py  (new file, 0 → 100644; view file @ 88c0268a)

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from cacheflow import pos_encoding_ops


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class RefRotaryEmbeddingNeox(nn.Module):
    """Reference implementation of the GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 2048,
        base: int = 10000,
    ) -> None:
        super().__init__()
        self.max_position_embeddings = max_position_embeddings

        # Create cos and sin embeddings.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype=inv_freq.dtype)
        sin = emb.sin().to(dtype=inv_freq.dtype)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def forward(
        self,
        positions: torch.LongTensor,  # [num_tokens]
        query: torch.Tensor,          # [num_tokens, num_heads, head_size]
        key: torch.Tensor,            # [num_tokens, num_heads, head_size]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cos = F.embedding(positions, self.cos_cached)
        sin = F.embedding(positions, self.sin_cached)

        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        query, key = apply_rotary_pos_emb(query, key, cos, sin)
        query = query.transpose(0, 1).contiguous()
        key = key.transpose(0, 1).contiguous()

        # Output query/key shape: [num_tokens, num_tokens, head_size]
        return query, key


@torch.inference_mode()
def test_rotary_embedding_neox(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    max_position: int,
    dtype: torch.dtype,
    base: int = 10000,
) -> None:
    positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
    query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
    key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')

    # Create the rotary embedding.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
    t = torch.arange(max_position).float()
    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
    cos = freqs.cos()
    sin = freqs.sin()
    cos_sin_cache = torch.cat((cos, sin), dim=-1)
    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')

    # Run the kernel.
    out_query = torch.empty_like(query)
    out_key = torch.empty_like(key)
    pos_encoding_ops.rotary_embedding_neox(
        out_query,
        out_key,
        positions,
        query,
        key,
        cos_sin_cache,
    )

    # Run the reference implementation.
    ref_rotary_embedding = RefRotaryEmbeddingNeox(
        dim=head_size,
        max_position_embeddings=max_position,
        base=base,
    ).to(dtype=dtype, device='cuda')
    ref_query, ref_key = ref_rotary_embedding(
        positions,
        query.view(num_tokens, num_heads, head_size),
        key.view(num_tokens, num_heads, head_size),
    )
    ref_query = ref_query.view(num_tokens, num_heads * head_size)
    ref_key = ref_key.view(num_tokens, num_heads * head_size)

    # Compare the results.
    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)


if __name__ == '__main__':
    for dtype in [torch.half, torch.float]:
        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
            print(f'Running tests for head_size={head_size} and dtype={dtype}')
            test_rotary_embedding_neox(
                num_tokens=2145,
                num_heads=5,
                head_size=head_size,
                max_position=8192,
                dtype=dtype,
            )
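
The test above is run as a plain script. If the project later adopts pytest, the same sweep could be expressed as a parametrized test; a hypothetical sketch, assuming a CUDA device, the compiled cacheflow extensions, and that it sits next to the function above:

    import pytest
    import torch

    @pytest.mark.parametrize('dtype', [torch.half, torch.float])
    @pytest.mark.parametrize('head_size', [32, 64, 80, 96, 128, 160, 192, 256])
    def test_rotary_embedding_neox_parametrized(dtype, head_size):
        # Reuses the script-style check defined above with the same sweep.
        # Note: the helper above would need a non-test_ name (or a collection
        # filter) so pytest does not also try to collect it directly.
        test_rotary_embedding_neox(
            num_tokens=2145,
            num_heads=5,
            head_size=head_size,
            max_position=8192,
            dtype=dtype,
        )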