Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7e9bd08f
Unverified
Commit
7e9bd08f
authored
Mar 13, 2024
by
Terry
Committed by
GitHub
Mar 13, 2024
Browse files
Add batched RoPE kernel (#3095)
parent
ae0ccb40
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
417 additions
and
37 deletions
+417
-37
benchmarks/kernels/benchmark_rope.py
benchmarks/kernels/benchmark_rope.py
+120
-0
csrc/ops.h
csrc/ops.h
+10
-0
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+111
-15
csrc/pybind.cpp
csrc/pybind.cpp
+5
-0
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+134
-1
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+37
-21
No files found.
benchmarks/kernels/benchmark_rope.py
0 → 100644
View file @
7e9bd08f
from
typing
import
Optional
import
argparse
import
torch
import
nvtx
from
itertools
import
accumulate
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
def
benchmark_rope_kernels_multi_lora
(
is_neox_style
:
bool
,
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
head_size
:
int
,
rotary_dim
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
)
->
None
:
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
set_default_device
(
device
)
if
rotary_dim
is
None
:
rotary_dim
=
head_size
# silulating serving 4 LoRAs
scaling_factors
=
[
1
,
2
,
4
,
8
]
# batched RoPE can take multiple scaling factors
batched_rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"type"
:
"linear"
,
"factor"
:
tuple
(
scaling_factors
)
})
# non-batched RoPE takes only one scaling factor, we create multiple
# instances to simulate the same behavior
non_batched_ropes
=
[]
for
scaling_factor
in
scaling_factors
:
non_batched_ropes
.
append
(
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"type"
:
"linear"
,
"factor"
:
(
scaling_factor
,
)
}))
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query
=
torch
.
randn
(
batch_size
,
seq_len
,
num_heads
*
head_size
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
# create query offsets for batched RoPE, we concat multiple kv cache
# together and each query needs to find the right kv cache of its type
offset_map
=
torch
.
tensor
(
list
(
accumulate
([
0
]
+
[
max_position
*
scaling_factor
*
2
for
scaling_factor
in
scaling_factors
[:
-
1
]
])))
query_types
=
torch
.
randint
(
0
,
len
(
scaling_factors
),
(
batch_size
,
seq_len
),
device
=
device
)
# map query types to offsets
query_offsets
=
offset_map
[
query_types
]
# the kernel takes flattened offsets
flatten_offsets
=
query_offsets
.
flatten
()
# batched queries of the same type together for non-batched RoPE
queries
=
[
query
[
query_types
==
i
]
for
i
in
range
(
len
(
scaling_factors
))]
keys
=
[
key
[
query_types
==
i
]
for
i
in
range
(
len
(
scaling_factors
))]
packed_qkr
=
zip
(
queries
,
keys
,
non_batched_ropes
)
# synchronize before start timing
torch
.
cuda
.
synchronize
()
with
nvtx
.
annotate
(
"non-batched"
,
color
=
"yellow"
):
for
q
,
k
,
r
in
packed_qkr
:
r
.
forward
(
positions
,
q
,
k
)
torch
.
cuda
.
synchronize
()
with
nvtx
.
annotate
(
"batched"
,
color
=
"green"
):
batched_rope
.
forward
(
positions
,
query
,
key
,
flatten_offsets
)
torch
.
cuda
.
synchronize
()
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Benchmark the rotary embedding kernels."
)
parser
.
add_argument
(
"--is-neox-style"
,
type
=
bool
,
default
=
True
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
16
)
parser
.
add_argument
(
"--seq-len"
,
type
=
int
,
default
=
512
)
parser
.
add_argument
(
"--num-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--head-size"
,
type
=
int
,
choices
=
[
64
,
80
,
96
,
112
,
128
,
256
],
default
=
128
)
parser
.
add_argument
(
"--rotary-dim"
,
type
=
int
,
choices
=
[
16
,
32
],
default
=
32
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"bfloat16"
,
"float"
],
default
=
"float"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
choices
=
[
"cuda:0"
,
"cuda:1"
],
default
=
"cuda:0"
)
args
=
parser
.
parse_args
()
print
(
args
)
benchmark_rope_kernels_multi_lora
(
is_neox_style
=
args
.
is_neox_style
,
batch_size
=
args
.
batch_size
,
seq_len
=
args
.
seq_len
,
num_heads
=
args
.
num_heads
,
head_size
=
args
.
head_size
,
rotary_dim
=
args
.
rotary_dim
,
dtype
=
getattr
(
torch
,
args
.
dtype
),
seed
=
args
.
seed
,
device
=
args
.
device
,
)
csrc/ops.h
View file @
7e9bd08f
...
...
@@ -53,6 +53,16 @@ void rotary_embedding(
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
...
csrc/pos_encoding_kernels.cu
View file @
7e9bd08f
...
...
@@ -8,7 +8,7 @@
namespace
vllm
{
template
<
typename
scalar_t
,
bool
IS_NEOX
>
inline
__device__
void
apply_rotary_embedding
(
inline
__device__
void
apply_
token_
rotary_embedding
(
scalar_t
*
__restrict__
arr
,
const
scalar_t
*
__restrict__
cos_ptr
,
const
scalar_t
*
__restrict__
sin_ptr
,
...
...
@@ -38,22 +38,18 @@ inline __device__ void apply_rotary_embedding(
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
__global__
void
rotary_embedding_kernel
(
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or [num_tokens]
inline
__device__
void
apply_rotary_embedding
(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim // 2]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
pos
*
rot_dim
;
const
int
rot_dim
,
const
int
token_idx
,
const
int64_t
query_stride
,
const
int64_t
key_stride
)
{
const
int
embed_dim
=
rot_dim
/
2
;
const
scalar_t
*
cos_ptr
=
cache_ptr
;
const
scalar_t
*
sin_ptr
=
cache_ptr
+
embed_dim
;
...
...
@@ -63,7 +59,7 @@ __global__ void rotary_embedding_kernel(
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_size
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
+
token_head
,
cos_ptr
,
apply_
token_
rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
...
...
@@ -72,11 +68,53 @@ __global__ void rotary_embedding_kernel(
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
apply_
token_
rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
__global__
void
rotary_embedding_kernel
(
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or [num_tokens]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim // 2]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
pos
*
rot_dim
;
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
__global__
void
batched_rotary_embedding_kernel
(
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or [num_tokens]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim // 2]
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len] or [num_tokens]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
int64_t
cos_sin_cache_offset
=
cos_sin_cache_offsets
[
token_idx
];
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
(
cos_sin_cache_offset
+
pos
)
*
rot_dim
;
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
}
}
// namespace vllm
void
rotary_embedding
(
...
...
@@ -128,3 +166,61 @@ void rotary_embedding(
}
});
}
/*
Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner.
*/
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
,
int
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
// [num_tokens]
)
{
int64_t
num_tokens
=
cos_sin_cache_offsets
.
size
(
0
);
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_kv_heads
=
key
.
size
(
-
1
)
/
head_size
;
int64_t
query_stride
=
query
.
stride
(
-
2
);
int64_t
key_stride
=
key
.
stride
(
-
2
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
if
(
is_neox
)
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
}
csrc/pybind.cpp
View file @
7e9bd08f
...
...
@@ -56,6 +56,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&
rotary_embedding
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"
);
ops
.
def
(
"batched_rotary_embedding"
,
&
batched_rotary_embedding
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)"
);
// Quantization ops
#ifndef USE_ROCM
ops
.
def
(
"awq_gemm"
,
&
awq_gemm
,
"Quantized GEMM for AWQ"
);
...
...
tests/kernels/test_pos_encoding.py
View file @
7e9bd08f
from
typing
import
Optional
from
typing
import
List
,
Optional
import
pytest
import
torch
from
allclose_default
import
get_default_atol
,
get_default_rtol
from
itertools
import
accumulate
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
IS_NEOX_STYLE
=
[
True
,
False
]
...
...
@@ -72,3 +73,135 @@ def test_rotary_embedding(
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"rotary_dim"
,
ROTARY_DIMS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_batched_rotary_embedding
(
is_neox_style
:
bool
,
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
head_size
:
int
,
rotary_dim
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
)
->
None
:
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
set_default_device
(
device
)
if
rotary_dim
is
None
:
rotary_dim
=
head_size
rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"type"
:
"linear"
,
"factor"
:
(
1
,
)
})
rope
=
rope
.
to
(
dtype
=
dtype
)
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query
=
torch
.
randn
(
batch_size
,
seq_len
,
num_heads
*
head_size
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query
,
ref_key
=
rope
.
_forward
(
positions
,
query
,
key
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
dtype
=
int
,
device
=
device
))
# Compare the results.
assert
torch
.
allclose
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
assert
torch
.
allclose
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"rotary_dim"
,
ROTARY_DIMS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_batched_rotary_embedding_multi_lora
(
is_neox_style
:
bool
,
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
head_size
:
int
,
rotary_dim
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
)
->
None
:
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
set_default_device
(
device
)
if
rotary_dim
is
None
:
rotary_dim
=
head_size
scaling_factors
:
List
[
int
]
=
[
1
,
2
,
4
]
rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"type"
:
"linear"
,
"factor"
:
tuple
(
scaling_factors
)
})
rope
=
rope
.
to
(
dtype
=
dtype
)
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query
=
torch
.
randn
(
batch_size
,
seq_len
,
num_heads
*
head_size
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
offset_map
=
torch
.
tensor
(
list
(
accumulate
([
0
]
+
[
max_position
*
scaling_factor
*
2
for
scaling_factor
in
scaling_factors
[:
-
1
]
])))
query_types
=
torch
.
randint
(
0
,
len
(
scaling_factors
),
(
batch_size
,
seq_len
),
device
=
device
)
query_offsets
=
offset_map
[
query_types
]
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query
,
ref_key
=
rope
.
_forward
(
positions
,
query
,
key
,
query_offsets
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
query_offsets
.
flatten
())
# Compare the results.
assert
torch
.
allclose
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
assert
torch
.
allclose
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
vllm/model_executor/layers/rotary_embedding.py
View file @
7e9bd08f
...
...
@@ -22,7 +22,7 @@
# limitations under the License.
"""Rotary Positional Embeddings."""
import
math
from
typing
import
Any
,
Dict
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch.nn
as
nn
...
...
@@ -96,6 +96,7 @@ class RotaryEmbedding(nn.Module):
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward()."""
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
...
...
@@ -107,7 +108,9 @@ class RotaryEmbedding(nn.Module):
query_pass
=
query
[...,
self
.
rotary_dim
:]
key_pass
=
key
[...,
self
.
rotary_dim
:]
cos_sin
=
self
.
cos_sin_cache
[
positions
]
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
get_device
())
cos_sin
=
self
.
cos_sin_cache
[
torch
.
add
(
positions
,
offsets
)
if
offsets
is
not
None
else
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
if
self
.
is_neox_style
:
# NOTE(woosuk): Here we assume that the positions tensor has the
...
...
@@ -137,11 +140,19 @@ class RotaryEmbedding(nn.Module):
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# ops.rotary_embedding() is an in-place operation that
# updates the query and key tensors.
ops
.
rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
self
.
cos_sin_cache
,
self
.
is_neox_style
)
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
get_device
())
# ops.rotary_embedding()/batched_rotary_embedding() are in-place operations that
# update the query and key tensors.
if
offsets
is
not
None
:
ops
.
batched_rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
self
.
cos_sin_cache
,
self
.
is_neox_style
,
self
.
rotary_dim
,
offsets
)
else
:
ops
.
rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
self
.
cos_sin_cache
,
self
.
is_neox_style
)
return
query
,
key
...
...
@@ -158,27 +169,32 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
max_position_embeddings
:
int
,
base
:
int
,
is_neox_style
:
bool
,
scaling_factor
:
float
,
scaling_factor
s
:
Union
[
List
[
float
],
float
]
,
)
->
None
:
self
.
scaling_factor
=
scaling_factor
if
isinstance
(
scaling_factors
,
float
):
scaling_factors
=
[
scaling_factors
]
self
.
scaling_factors
=
scaling_factors
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
)
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
inv_freq
=
self
.
_compute_inv_freq
(
self
.
base
)
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len
=
self
.
max_position_embeddings
*
self
.
scaling_factor
t
=
torch
.
arange
(
max_len
,
dtype
=
torch
.
float
)
t
=
t
/
self
.
scaling_factor
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
return
cache
cache_list
=
[]
for
scaling_factor
in
self
.
scaling_factors
:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len
=
self
.
max_position_embeddings
*
scaling_factor
t
=
torch
.
arange
(
max_len
,
dtype
=
torch
.
float
)
t
=
t
/
scaling_factor
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
cache_list
.
append
(
cache
)
return
torch
.
cat
(
cache_list
,
dim
=
0
)
class
DynamicNTKScalingRotaryEmbedding
(
RotaryEmbedding
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment