Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6ad6b8e1
Unverified
Commit
6ad6b8e1
authored
Aug 04, 2025
by
TJian
Committed by
GitHub
Aug 04, 2025
Browse files
[FEAT] Refactor ROPE into module (#22192)
Signed-off-by:
tjtanaa
<
tunjian.tan@embeddedllm.com
>
parent
f4f4e7ef
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
2111 additions
and
0 deletions
+2111
-0
vllm/model_executor/layers/rotary_embedding/__init__.py
vllm/model_executor/layers/rotary_embedding/__init__.py
+190
-0
vllm/model_executor/layers/rotary_embedding/base.py
vllm/model_executor/layers/rotary_embedding/base.py
+237
-0
vllm/model_executor/layers/rotary_embedding/common.py
vllm/model_executor/layers/rotary_embedding/common.py
+105
-0
vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
...executor/layers/rotary_embedding/deepseek_scaling_rope.py
+131
-0
vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py
...model_executor/layers/rotary_embedding/dual_chunk_rope.py
+188
-0
vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py
...xecutor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py
+41
-0
vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py
...cutor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py
+67
-0
vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py
...l_executor/layers/rotary_embedding/linear_scaling_rope.py
+115
-0
vllm/model_executor/layers/rotary_embedding/llama3_rope.py
vllm/model_executor/layers/rotary_embedding/llama3_rope.py
+54
-0
vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
...el_executor/layers/rotary_embedding/llama4_vision_rope.py
+74
-0
vllm/model_executor/layers/rotary_embedding/mrope.py
vllm/model_executor/layers/rotary_embedding/mrope.py
+670
-0
vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py
...odel_executor/layers/rotary_embedding/ntk_scaling_rope.py
+42
-0
vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
...tor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
+129
-0
vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py
...del_executor/layers/rotary_embedding/yarn_scaling_rope.py
+68
-0
No files found.
vllm/model_executor/layers/rotary_embedding/__init__.py
0 → 100644
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Rotary Positional Embeddings."""
from
typing
import
Any
,
Optional
import
torch
from
.base
import
RotaryEmbedding
from
.deepseek_scaling_rope
import
DeepseekScalingRotaryEmbedding
from
.dual_chunk_rope
import
DualChunkRotaryEmbedding
from
.dynamic_ntk_alpha_rope
import
DynamicNTKAlphaRotaryEmbedding
from
.dynamic_ntk_scaling_rope
import
DynamicNTKScalingRotaryEmbedding
from
.linear_scaling_rope
import
LinearScalingRotaryEmbedding
from
.llama3_rope
import
Llama3RotaryEmbedding
from
.llama4_vision_rope
import
Llama4VisionRotaryEmbedding
from
.mrope
import
MRotaryEmbedding
from
.ntk_scaling_rope
import
NTKScalingRotaryEmbedding
from
.phi3_long_rope_scaled_rope
import
Phi3LongRoPEScaledRotaryEmbedding
from
.yarn_scaling_rope
import
YaRNScalingRotaryEmbedding
# Process-wide cache of rotary embedding modules, keyed by the hashable
# argument tuple that get_rope() assembles.
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}


def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    """Return the rotary embedding module matching the given configuration.

    Modules are memoized in ``_ROPE_DICT``: a second call with an equal
    configuration returns the very same module instance.

    Args:
        head_size: Forwarded to the embedding constructor.
        rotary_dim: Forwarded to the embedding constructor; scaled by
            ``partial_rotary_factor`` first when that factor is below 1.0.
        max_position: Forwarded to the embedding constructor (the "yarn" and
            "deepseek_yarn" variants receive
            ``rope_scaling["original_max_position_embeddings"]`` instead).
        base: Forwarded to the embedding constructor.
        is_neox_style: Forwarded to the embedding constructor.
        rope_scaling: Optional scaling configuration. When non-empty, its
            ``"rope_type"`` entry selects the embedding class.
        dtype: Cache dtype; defaults to ``torch.get_default_dtype()``.
        partial_rotary_factor: When below 1.0, ``rotary_dim`` is multiplied
            by it and truncated to an int.
        dual_chunk_attention_config: When given, a
            ``DualChunkRotaryEmbedding`` is built regardless of
            ``rope_scaling``.

    Raises:
        ValueError: If the scaling type is unknown, or a "dynamic" config
            has neither an 'alpha' nor a 'factor' field.
    """
    if dtype is None:
        dtype = torch.get_default_dtype()

    # Lists are unhashable, so list values become tuples before the
    # configurations are folded into the cache key.
    rope_scaling_args = None
    if rope_scaling is not None:
        rope_scaling_args = tuple(
            (name, tuple(value) if isinstance(value, list) else value)
            for name, value in rope_scaling.items())

    dual_chunk_attention_args = None
    if dual_chunk_attention_config is not None:
        # "sparse_attention_config" is deliberately left out of the key.
        dual_chunk_attention_args = tuple(
            (name, tuple(value) if isinstance(value, list) else value)
            for name, value in dual_chunk_attention_config.items()
            if name != "sparse_attention_config")

    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)

    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args, dual_chunk_attention_args, dtype)
    cached = _ROPE_DICT.get(key)
    if cached is not None:
        return cached

    def _select(config: dict[str, Any], names: tuple) -> dict[str, Any]:
        # Keep only the optional entries that the target class accepts.
        return {
            name: value
            for name, value in config.items() if name in names
        }

    if dual_chunk_attention_config is not None:
        rotary_emb = DualChunkRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype,
            **_select(dual_chunk_attention_config,
                      ("chunk_size", "local_size")))
    elif not rope_scaling:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position,
                                     base, is_neox_style, dtype)
    else:
        scaling_type = rope_scaling["rope_type"]

        if scaling_type == "llama3":
            factor = rope_scaling["factor"]
            low_freq = rope_scaling["low_freq_factor"]
            high_freq = rope_scaling["high_freq_factor"]
            orig_max_pos = rope_scaling["original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
                                               is_neox_style, dtype, factor,
                                               low_freq, high_freq,
                                               orig_max_pos)
        elif scaling_type == "mllama4":
            rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
                                                     max_position, base,
                                                     is_neox_style, dtype)
        elif scaling_type == "default":
            if "mrope_section" in rope_scaling:
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                    mrope_section=rope_scaling["mrope_section"],
                )
            else:
                rotary_emb = RotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                )
        elif scaling_type == "linear":
            factor = rope_scaling["factor"]
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style, factor,
                                                      dtype)
        elif scaling_type == "ntk":
            factor = rope_scaling["factor"]
            mixed_b = rope_scaling.get("mixed_b", None)
            rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
                                                   max_position, base,
                                                   is_neox_style, factor,
                                                   dtype, mixed_b)
        elif scaling_type == "dynamic":
            if "alpha" in rope_scaling:
                alpha = rope_scaling["alpha"]
                rotary_emb = DynamicNTKAlphaRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    alpha, dtype)
            elif "factor" in rope_scaling:
                factor = rope_scaling["factor"]
                rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    factor, dtype)
            else:
                raise ValueError("Dynamic rope scaling must contain either "
                                 "'alpha' or 'factor' field")
        elif scaling_type == "yarn":
            factor = rope_scaling["factor"]
            orig_max_pos = rope_scaling["original_max_position_embeddings"]
            yarn_kwargs = _select(rope_scaling,
                                  ("extrapolation_factor", "attn_factor",
                                   "beta_fast", "beta_slow"))
            # The original (unscaled) context length is passed as the
            # position limit here, not max_position.
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    orig_max_pos, base,
                                                    is_neox_style, factor,
                                                    dtype, **yarn_kwargs)
        elif scaling_type == "deepseek_yarn":
            factor = rope_scaling["factor"]
            orig_max_pos = rope_scaling["original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            deepseek_kwargs = _select(
                rope_scaling,
                ("extrapolation_factor", "attn_factor", "beta_fast",
                 "beta_slow", "mscale", "mscale_all_dim"))
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, orig_max_pos, base, is_neox_style,
                factor, dtype, **deepseek_kwargs)
        elif scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            orig_max_pos = rope_scaling["original_max_position_embeddings"]
            longrope_kwargs = _select(rope_scaling,
                                      ("short_mscale", "long_mscale"))
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size, rotary_dim, max_position, orig_max_pos, base,
                is_neox_style, dtype, short_factor, long_factor,
                **longrope_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
vllm/model_executor/layers/rotary_embedding/base.py
0 → 100644
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Rotary Positional Embeddings Base Class."""
from
typing
import
Optional
import
torch
from
vllm.model_executor.custom_op
import
CustomOp
from
.common
import
apply_rotary_emb_dispatch
,
apply_rotary_emb_torch
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding.

    A table of cos/sin values is precomputed at construction time and kept
    in the non-persistent buffer ``cos_sin_cache``. Row ``p`` of the table
    holds the cosines followed by the sines for position ``p``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        """Store the configuration and build the cos/sin cache.

        Args:
            head_size: Size of one attention head; query/key are viewed as
                ``[num_tokens, -1, head_size]`` in the forward methods.
            rotary_dim: Number of leading features of each head that are
                rotated; the remainder is passed through unchanged.
            max_position_embeddings: Number of positions in the cache.
            base: Base of the inverse-frequency computation.
            is_neox_style: Selects the Neox-style (split halves) or
                GPT-J-style (interleaved) layout.
            dtype: dtype the cache is converted to.
        """
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache.

        Returns:
            A float tensor with one row per position whose last dimension
            is the cosines concatenated with the sines.
        """
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def _match_cos_sin_cache(self, device: torch.device,
                             dtype: torch.dtype) -> None:
        """Move/cast ``cos_sin_cache`` only when it does not already match.

        __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
        is expensive, so the assignment is skipped when the buffer already
        has the requested device and dtype.
        """
        if (self.cos_sin_cache.device != device
                or self.cos_sin_cache.dtype != dtype):
            self.cos_sin_cache = self.cos_sin_cache.to(device, dtype=dtype)

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """A PyTorch-native implementation of forward().

        Args:
            positions: Position indices; flattened before the cache lookup.
            query: Query tensor, viewed as ``[num_tokens, -1, head_size]``.
            key: Optional key tensor handled the same way as ``query``.
            offsets: Optional tensor added to ``positions``.

        Returns:
            ``(query, key)`` in their original shapes; ``key`` is ``None``
            when no key was given.
        """
        if offsets is not None:
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = apply_rotary_emb_torch(query_rot, cos, sin,
                                           self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        # key may be None in some cases, e.g. cross-layer KV sharing
        if key is not None:
            key_shape = key.shape
            key = key.view(num_tokens, -1, self.head_size)
            key_rot = key[..., :self.rotary_dim]
            key_pass = key[..., self.rotary_dim:]
            key_rot = apply_rotary_emb_torch(key_rot, cos, sin,
                                             self.is_neox_style)
            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the rotation with the custom ops from ``vllm._custom_ops``.

        Returns the same ``query``/``key`` objects that were passed in,
        after the in-place kernels have run.
        """
        from vllm import _custom_ops as ops

        self._match_cos_sin_cache(query.device, query.dtype)

        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                         self.cos_sin_cache,
                                         self.is_neox_style, self.rotary_dim,
                                         offsets)
        else:
            ops.rotary_embedding(positions, query, key, self.head_size,
                                 self.cos_sin_cache, self.is_neox_style)
        return query, key

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the rotation with the ops from ``vllm._ipex_ops``.

        Falls back to ``forward_native`` when ``key`` is ``None``.
        """
        from vllm._ipex_ops import ipex_ops as ops

        # The cache follows the device of `positions` and the dtype of
        # `query` on this backend.
        self._match_cos_sin_cache(positions.device, query.dtype)

        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if key is None:
            # XPU kernel doesn't support key=None so fall back to native impl
            # TODO(sarckk): add support for optional key in
            # ipex.llm.functional.rotary_embedding_batched
            return self.forward_native(positions, query, key, offsets)
        else:
            if offsets is not None:
                ops.batched_rotary_embedding(positions, query, key,
                                             self.head_size,
                                             self.cos_sin_cache,
                                             self.is_neox_style,
                                             self.rotary_dim, offsets)
            else:
                ops.rotary_embedding(positions, query, key, self.head_size,
                                     self.cos_sin_cache, self.is_neox_style)
        return query, key

    def forward_neuron(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the rotation using view-based slicing.

        When ``rotary_dim == head_size`` the whole head is rotated through
        ``apply_rotary_emb_dispatch``; otherwise only the first
        ``rotary_dim`` features are rotated and the rest is concatenated
        back unchanged.
        """

        def _apply_rotary_emb_neuron(
            x: torch.Tensor,
            cos: torch.Tensor,
            sin: torch.Tensor,
            is_neox_style: bool,
        ) -> torch.Tensor:
            cos = cos.unsqueeze(-2).to(x.dtype)
            sin = sin.unsqueeze(-2).to(x.dtype)
            if is_neox_style:
                x1, x2 = torch.chunk(x, 2, dim=-1)
            else:
                # x1 = x[..., ::2]
                # x2 = x[..., 1::2]
                d = x.shape[-1] // 2
                x_reshaped = x.view(-1, x.shape[-1])
                x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d)
                x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d)
            o1 = x1 * cos - x2 * sin
            o2 = x2 * cos + x1 * sin
            if is_neox_style:
                return torch.cat((o1, o2), dim=-1)
            else:
                return torch.stack((o1, o2), dim=-1).flatten(-2)

        if offsets is not None:
            positions = positions + offsets

        self._match_cos_sin_cache(query.device, query.dtype)

        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        if key is not None:
            key_shape = key.shape
            key = key.view(num_tokens, -1, self.head_size)

        if self.rotary_dim == self.head_size:
            query = apply_rotary_emb_dispatch(query, cos, sin,
                                              self.is_neox_style)
            query = query.reshape(query_shape)
            if key is not None:
                key = apply_rotary_emb_dispatch(key, cos, sin,
                                                self.is_neox_style)
                key = key.reshape(key_shape)
        else:
            head_size = query.shape[-1]
            query_reshaped = query.view(-1, head_size)
            query_pass = query_reshaped[:, self.rotary_dim:].view(
                *query.shape[:-1], head_size - self.rotary_dim)
            query_rot = query_reshaped[:, :self.rotary_dim].view(
                *query.shape[:-1], self.rotary_dim)
            query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin,
                                                 self.is_neox_style)
            query = torch.cat((query_rot, query_pass),
                              dim=-1).reshape(query_shape)

            if key is not None:
                key_reshaped = key.view(-1, head_size)
                key_pass = key_reshaped[:, self.rotary_dim:].view(
                    *key.shape[:-1], head_size - self.rotary_dim)
                key_rot = key_reshaped[:, :self.rotary_dim].view(
                    *key.shape[:-1], self.rotary_dim)
                key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
                                                   self.is_neox_style)
                key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def extra_repr(self) -> str:
        """Return the configuration summary appended to the module repr."""
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        return s
vllm/model_executor/layers/rotary_embedding/common.py
0 → 100644
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
torch
from
vllm.platforms
import
current_platform
if
current_platform
.
is_cuda
():
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
# common functions
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    With the last dimension split at ``x.shape[-1] // 2`` into ``(a, b)``,
    the result is ``(-b, a)`` concatenated along the last dimension.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved pairs of the last dimension.

    Each adjacent pair ``(even, odd)`` along the last dimension becomes
    ``(-odd, even)``; the pairs stay in their original order.
    """
    evens = x[..., ::2]
    odds = x[..., 1::2]
    paired = torch.stack((-odds, evens), dim=-1)
    return paired.flatten(-2)
def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """Apply a rotary embedding to ``x`` with plain PyTorch ops.

    ``cos`` and ``sin`` get a singleton dimension inserted at position -2
    and are cast to ``x.dtype`` before use. In Neox style the last
    dimension of ``x`` is split into two halves; otherwise even and odd
    indices form the two components. The rotated components are merged
    back in the same layout they were split from.
    """
    cos_b = cos.unsqueeze(-2).to(x.dtype)
    sin_b = sin.unsqueeze(-2).to(x.dtype)

    if is_neox_style:
        first, second = torch.chunk(x, 2, dim=-1)
    else:
        first, second = x[..., ::2], x[..., 1::2]

    rotated_first = first * cos_b - second * sin_b
    rotated_second = second * cos_b + first * sin_b

    if is_neox_style:
        return torch.cat((rotated_first, rotated_second), dim=-1)
    return torch.stack((rotated_first, rotated_second), dim=-1).flatten(-2)
def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
                              sin: torch.Tensor,
                              is_neox_style: bool) -> torch.Tensor:
    """Apply a rotary embedding using the platform-appropriate kernel.

    On CUDA the flash-attn ``apply_rotary_emb`` is called on ``x`` with a
    temporary leading dimension and the negated style flag; every other
    platform uses ``apply_rotary_emb_torch``.

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    if not current_platform.is_cuda():
        return apply_rotary_emb_torch(x, cos, sin, is_neox_style)

    batched = x.unsqueeze(0)
    rotated = apply_rotary_emb(batched, cos, sin, not is_neox_style)
    return rotated.squeeze(0)
# yarn functions
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(num_rotations: int,
                             dim: int,
                             base: float = 10000,
                             max_position_embeddings: int = 2048) -> float:
    """Return the (fractional) dimension index for a number of rotations.

    Evaluates
    ``dim * ln(max_position_embeddings / (2 * pi * num_rotations))
    / (2 * ln(base))``.
    """
    full_turns = num_rotations * 2 * math.pi
    numerator = dim * math.log(max_position_embeddings / full_turns)
    denominator = 2 * math.log(base)
    return numerator / denominator
# Find dim range bounds based on rotations
def yarn_find_correction_range(
        low_rot: int,
        high_rot: int,
        dim: int,
        base: float = 10000,
        max_position_embeddings: int = 2048) -> tuple[int, int]:
    """Return the ``(low, high)`` dimension bounds for a rotation range.

    The lower bound is the floor of the correction dim for ``low_rot`` and
    the upper bound the ceiling of the correction dim for ``high_rot``.
    """
    lower = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    upper = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base,
                                 max_position_embeddings))
    # Clamp values just in case
    lower = max(lower, 0)
    upper = min(upper, dim - 1)
    return lower, upper
def yarn_linear_ramp_mask(low: float, high: float, dim: int,
                          dtype: torch.dtype) -> torch.Tensor:
    """Return a length-``dim`` ramp rising linearly from 0 to 1.

    Element ``i`` is ``(i - low) / (high - low)`` clamped to ``[0, 1]``.
    The tensor is created with ``torch.arange(dim, dtype=dtype)`` and no
    explicit device.
    """
    if low == high:
        high += 0.001  # Prevent singularity

    indices = torch.arange(dim, dtype=dtype)
    return torch.clamp((indices - low) / (high - low), 0, 1)
def yarn_get_mscale(scale: float = 1) -> float:
    """Return the YaRN magnitude scale for a scaling factor.

    The result is 1.0 for ``scale <= 1`` and ``0.1 * ln(scale) + 1.0``
    otherwise.
    """
    return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0
vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
0 → 100644
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
typing
import
Optional
import
torch
from
vllm.platforms
import
current_platform
from
.base
import
RotaryEmbedding
from
.common
import
(
rotate_gptj
,
rotate_neox
,
yarn_find_correction_range
,
yarn_linear_ramp_mask
)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the magnitude scale for a scaling factor and a multiplier.

    The result is 1.0 for ``scale <= 1`` and
    ``0.1 * mscale * ln(scale) + 1.0`` otherwise.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        """Store the YaRN parameters, then build the cache via the base class.

        Args:
            head_size, rotary_dim, max_position_embeddings, base,
            is_neox_style, dtype: Forwarded to ``RotaryEmbedding``.
            scaling_factor: Interpolation factor; the cache covers
                ``max_position_embeddings * scaling_factor`` positions.
            extrapolation_factor: Multiplier applied to the blend mask.
            attn_factor: Multiplier folded into ``self.mscale``.
            beta_fast, beta_slow: Rotation counts passed to
                ``yarn_find_correction_range``.
            mscale, mscale_all_dim: Multipliers for the numerator and
                denominator of the magnitude scale ratio.
        """
        # These attributes must exist before super().__init__() runs,
        # because the base constructor calls _compute_cos_sin_cache().
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            yarn_get_mscale(self.scaling_factor, float(mscale)) /
            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
            attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Blend interpolated and extrapolated inverse frequencies.

        Note that, unlike the base class, the argument here is the scaling
        factor rather than the base.
        """
        pos_freqs = self.base**(
            torch.arange(0,
                         self.rotary_dim,
                         2,
                         dtype=torch.float,
                         device=current_platform.device_type) /
            self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                               self.rotary_dim, self.base,
                                               self.max_position_embeddings)
        # The ramp helper creates its tensor without an explicit device,
        # while pos_freqs is created on the platform device. Move the ramp
        # over so the blend below does not depend on an ambient default
        # device; this is a no-op when both already agree.
        ramp = yarn_linear_ramp_mask(low,
                                     high,
                                     self.rotary_dim // 2,
                                     dtype=torch.float).to(pos_freqs.device)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - ramp) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin cache, scaled by ``self.mscale``.

        The cache has one row per position in
        ``range(max_position_embeddings * scaling_factor)``.
        """
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         device=current_platform.device_type,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Rotate the first ``rotary_dim`` features of query and key.

        Args:
            positions: Indices into the cos/sin cache.
            query: Query tensor whose last dimension is sliced at
                ``rotary_dim``.
            key: Key tensor; must not be ``None`` (asserted).
            offsets: Optional tensor added to ``positions``.

        Returns:
            The rotated ``(query, key)`` pair.
        """
        assert key is not None
        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]

        if self.cos_sin_cache.device != positions.device:
            self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
                positions.device)
        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                     if offsets is not None else positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        return query, key
vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py
0 → 100644
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
from
vllm.model_executor.custom_op
import
CustomOp
from
.common
import
rotate_gptj
,
rotate_neox
@CustomOp.register("dual_chunk_rotary_embedding")
class DualChunkRotaryEmbedding(CustomOp):
    """Rotary positional embedding for Dual Chunk Attention."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        chunk_size: int,
        local_size: int,
    ) -> None:
        """Store the configuration and register the five cos/sin caches.

        Args:
            head_size: Size of one attention head.
            rotary_dim: Number of leading features of each head that are
                rotated.
            max_position_embeddings: Number of rows in the key cache.
            base: Base of the inverse-frequency computation.
            is_neox_style: Selects ``rotate_neox`` or ``rotate_gptj``.
            dtype: dtype of the caches.
            chunk_size: Chunk size; ``chunk_size - local_size`` is the
                modulus applied to positions.
            local_size: Local window size subtracted from ``chunk_size``.
        """
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.chunk_size = chunk_size
        self.local_size = local_size
        self.dtype = dtype
        # The caches are placed on the current CUDA device.
        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
        (q_cache, qc_cache, k_cache, qc_no_clamp_cache,
         q_inter_cache) = self._compute_cos_sin_cache()

        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
        self.register_buffer("cos_sin_qc_no_clamp_cache",
                             qc_no_clamp_cache,
                             persistent=False)
        self.register_buffer("cos_sin_q_inter_cache",
                             q_inter_cache,
                             persistent=False)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(
        self
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Compute the cos and sin caches.

        Returns:
            ``(q_cache, qc_cache, k_cache, qc_no_clamp_cache,
            q_inter_cache)``. Each cache holds cosines followed by sines
            along the last dimension, in ``self.dtype`` on ``self.device``.
        """
        inv_freq = self._compute_inv_freq(self.base)
        chunk_len = self.chunk_size - self.local_size

        def _cache_for(t: torch.Tensor) -> torch.Tensor:
            # cos/sin table for the position vector `t`.
            freqs = torch.outer(t, inv_freq)
            return torch.cat((freqs.cos(), freqs.sin()),
                             dim=-1).to(dtype=self.dtype, device=self.device)

        q_t = torch.arange(chunk_len, dtype=torch.float)
        qc_t = (torch.arange(chunk_len, dtype=torch.float) +
                chunk_len).clamp(max=self.chunk_size)
        k_t = torch.arange(self.max_position_embeddings,
                           dtype=torch.float) % chunk_len
        # count from chunk_len, no clamp(self.chunk_size) restriction
        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
        # count from self.chunk_size for q_inter's rope
        q_inter_t = torch.arange(chunk_len,
                                 dtype=torch.float) + self.chunk_size

        return (_cache_for(q_t), _cache_for(qc_t), _cache_for(k_t),
                _cache_for(qc_no_clamp_t), _cache_for(q_inter_t))

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate key once and query with five different position tables.

        Returns:
            ``(query, key)`` where ``query`` is the concatenation, along
            the last dimension, of the five rotated query variants (plain,
            succ, inter, succ-critical, inter-critical), in that order.
        """
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]
        else:
            query_pass = None
            key_pass = None

        positions_with_offsets = (torch.add(positions, offsets)
                                  if offsets is not None else positions)
        key = self._apply_rotary_embedding(
            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass)
        chunk_len = self.chunk_size - self.local_size
        query = self._apply_rotary_embedding(
            self.cos_sin_q_cache[positions_with_offsets % chunk_len],
            query_rot, query_pass)
        query_succ = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
            query_rot, query_pass)
        # The inter variant uses one fixed cache row (chunk_len - 1),
        # repeated once per entry of positions.shape[0].
        query_inter = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
            query_rot, query_pass)
        query_succ_critical = self._apply_rotary_embedding(
            self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
            query_rot, query_pass)
        query_inter_critical = self._apply_rotary_embedding(
            self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
            query_rot, query_pass)

        # merge query into one tensor to simplify the interfaces
        query = torch.cat((
            query,
            query_succ,
            query_inter,
            query_succ_critical,
            query_inter_critical,
        ),
                          dim=-1)
        return query, key

    def _apply_rotary_embedding(
            self, cos_sin: torch.Tensor, hidden_rot: torch.Tensor,
            hidden_pass: Optional[torch.Tensor]) -> torch.Tensor:
        """Rotate ``hidden_rot`` and re-attach ``hidden_pass`` if present.

        ``hidden_pass`` is only used when ``rotary_dim < head_size``. The
        result has its last two dimensions flattened and dimension 0
        squeezed.
        """
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin

        if self.rotary_dim < self.head_size:
            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
        else:
            hidden = hidden_rot
        return hidden.flatten(-2).squeeze(0)

    def extra_repr(self) -> str:
        """Return the configuration summary appended to the module repr."""
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
        return s
vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py
0 → 100644
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
.base
import
RotaryEmbedding
class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK alpha.

    Based on the original RotaryEmbedding implementation.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_alpha: float,
        dtype: torch.dtype,
    ) -> None:
        # Must be stored before the parent constructor runs, because the
        # parent builds the cache through _compute_cos_sin_cache().
        self.scaling_alpha = scaling_alpha
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style, dtype)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin cache using the alpha-scaled rotary base."""
        # For Hunyuan DynamicNTKAlphaRotaryEmbedding
        exponent = self.rotary_dim / (self.rotary_dim - 2)
        ntk_base = self.base * self.scaling_alpha**exponent
        inv_freq = self._compute_inv_freq(ntk_base)

        positions = torch.arange(self.max_position_embeddings,
                                 dtype=torch.float)
        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        # Cos values fill the first half of the last dim, sin the second.
        return torch.cat((angles.cos(), angles.sin()), dim=-1)
vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py
0 → 100644
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
from
.base
import
RotaryEmbedding
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
    ) -> None:
        # Must be stored before the parent constructor runs, because the
        # parent builds the cache through _compute_cos_sin_cache().
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style, dtype)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin cache for the scaled maximum length."""
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        scaled_len = self.max_position_embeddings * self.scaling_factor

        length_ratio = (self.scaling_factor * scaled_len /
                        self.max_position_embeddings)
        growth = length_ratio - (self.scaling_factor - 1)
        exponent = self.rotary_dim / (self.rotary_dim - 2)
        ntk_base = self.base * growth**exponent
        inv_freq = self._compute_inv_freq(ntk_base)

        positions = torch.arange(scaled_len, dtype=torch.float)
        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        # Cos values fill the first half of the last dim, sin the second.
        return torch.cat((angles.cos(), angles.sin()), dim=-1)
vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py
0 → 100644
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Union
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
from
.base
import
RotaryEmbedding
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    It supports multiple scaling factors. Since multiple LoRA adapters may
    have different scaling factors, we need multiple cos/sin caches. In this
    way, instead of running rotary embedding kernel per lora, we can run
    multiple lora in a batched way.

    One cos/sin cache is built per entry of ``scaling_factors`` and the
    caches are stacked along the position (row) axis, in the order the
    factors were given:

        [cache(factor_0);
         cache(factor_1);
         ...]

    We then use offsets to index into the stacked cache for the respective
    scaling factors. The row offset at which each factor's cache starts can
    be accessed via the `scaling_factor_to_offset` API.

    NOTE(review): this class does not add a cache for the default factor of
    1 on its own; callers that need it presumably include 1.0 in
    ``scaling_factors`` -- confirm against the call sites.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factors: Union[list[float], float],
        dtype: torch.dtype,
    ) -> None:
        # Accept a bare number (int as well as float, e.g. ``2`` read from a
        # JSON config) and normalise it to a one-element list.
        if isinstance(scaling_factors, (int, float)):
            scaling_factors = [scaling_factors]
        self.scaling_factors: list[float] = scaling_factors  # noqa
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style, dtype)
        # Lazy initialized.
        self._scaling_factor_to_offset: dict[float, int]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build and stack one cos/sin cache per scaling factor.

        Also records, in ``self._scaling_factor_to_offset``, the first row
        of each factor's cache inside the stacked tensor.
        """
        inv_freq = self._compute_inv_freq(self.base)
        cache_list: list[torch.Tensor] = []
        # offsets to the next cache in a tensor.
        # Each offset corresponds to the same index in scaling_factors.
        offsets: list[int] = []
        next_offset = 0
        for scaling_factor in self.scaling_factors:
            # NOTE(woosuk): self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * self.scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float)
            t = t / scaling_factor

            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)

            # This cache starts right after all previously stacked rows.
            offsets.append(next_offset)
            cache_list.append(cache)
            next_offset += cache.shape[0]

        assert len(self.scaling_factors) == len(offsets)
        self._scaling_factor_to_offset = {
            float(scaling_factor): offset
            for scaling_factor, offset in zip(self.scaling_factors, offsets)
        }
        return torch.cat(cache_list, dim=0)

    @property
    def scaling_factor_to_offset(self) -> dict[float, int]:
        """Mapping from scaling factor to its first row in the cache."""
        return self._scaling_factor_to_offset
vllm/model_executor/layers/rotary_embedding/llama3_rope.py
0 → 100644
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
torch
from
.base
import
RotaryEmbedding
class Llama3RotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose inverse frequencies are rescaled per
    wavelength band (see `_compute_inv_freq`)."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ) -> None:
        # These must exist before the parent constructor runs, since it
        # ends up calling _compute_inv_freq().
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style, dtype)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return the parent's inverse frequencies, rescaled by band.

        Wavelengths shorter than ``orig_max_position / high_freq_factor``
        keep their frequency, wavelengths longer than
        ``orig_max_position / low_freq_factor`` are divided by
        ``scaling_factor``, and the band in between is blended.
        """
        inv_freqs = super()._compute_inv_freq(base)
        wave_len = 2 * math.pi / inv_freqs

        short_cutoff = self.orig_max_position / self.high_freq_factor
        long_cutoff = self.orig_max_position / self.low_freq_factor

        if self.low_freq_factor == self.high_freq_factor:
            # The blend weight is undefined (0/0) here; use 0.
            smooth = 0
        else:
            smooth = (self.orig_max_position / wave_len -
                      self.low_freq_factor) / (self.high_freq_factor -
                                               self.low_freq_factor)

        fully_scaled = inv_freqs / self.scaling_factor
        blended = ((1 - smooth) * inv_freqs / self.scaling_factor +
                   smooth * inv_freqs)
        mid_or_long = torch.where(wave_len > long_cutoff, fully_scaled,
                                  blended)
        return torch.where(wave_len < short_cutoff, inv_freqs, mid_or_long)
vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
0 → 100644
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
typing
import
Optional
import
torch
from
.base
import
RotaryEmbedding
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding over image patch indices, stored as a complex
    cache and applied by complex multiplication."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ):
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style, dtype)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return the first ``rotary_dim // 2`` parent inverse
        frequencies."""
        inv_freqs = super()._compute_inv_freq(base)
        inv_freqs = inv_freqs[:(self.rotary_dim // 2)]
        return inv_freqs

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the complex ``cos + i*sin`` cache for all patches plus
        one extra trailing entry whose angles are forced to zero."""
        inv_freq = self._compute_inv_freq(self.base)

        # self.max_position_embeddings here is number of image patches
        # i.e. (image_size // patch_size) ** 2
        num_patches = self.max_position_embeddings
        img_idx = torch.arange(num_patches,
                               dtype=torch.int32) \
            .reshape(num_patches, 1)
        # Append one extra row and mark it with a negative id so that it
        # is zeroed by the masked_fill below.
        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
        img_idx[-1, -1] = -2  # set to ID_CLS_TOKEN
        num_patches_single_dim = int(math.sqrt(num_patches))
        # Column and row of each patch index in the square grid.
        frequencies_x = img_idx % num_patches_single_dim
        frequencies_y = img_idx // num_patches_single_dim
        freqs_x = ((frequencies_x + 1)[..., None] *
                   inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
        freqs_y = ((frequencies_y + 1)[..., None] *
                   inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
        # Keep every second entry of the concatenated x/y angles.
        freqs = torch.cat([freqs_x, freqs_y],
                          dim=-1).float().contiguous()[..., ::2]
        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
        cache = torch.view_as_complex(
            torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
        return cache

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Rotate ``query`` and ``key`` with the cached complex factors.

        Args:
            query: Query tensor; consecutive pairs along the last
                dimension are treated as (real, imag).
            key: Key tensor, handled the same way. Must not be None.

        Returns:
            The rotated ``(query, key)``, cast back to the input dtypes.

        Raises:
            AssertionError: If ``key`` is None.
        """
        assert key is not None
        # Assigning a buffer goes through nn.Module.__setattr__, which is
        # expensive, so only move the cache when it is on another device.
        if self.cos_sin_cache.device != query.device:
            self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
                query.device)
        query_ = torch.view_as_complex(query.float().reshape(
            *query.shape[:-1], -1, 2))
        key_ = torch.view_as_complex(key.float().reshape(
            *key.shape[:-1], -1, 2))
        # Keep dim 1 and the last dim of the complex query; every other
        # dim becomes 1 so the cache broadcasts over it.
        broadcast_shape = [
            d if i == 1 or i == (query_.ndim - 1) else 1
            for i, d in enumerate(query_.shape)
        ]
        freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
        query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
        key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
        return query_out.type_as(query), key_out.type_as(key)
vllm/model_executor/layers/rotary_embedding.py
→
vllm/model_executor/layers/rotary_embedding
/mrope
.py
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
import
itertools
import
itertools
import
math
from
typing
import
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.model_executor.custom_op
import
CustomOp
from
.base
import
RotaryEmbedding
from
vllm.platforms
import
current_platform
from
.common
import
apply_rotary_emb_dispatch
if
current_platform
.
is_cuda
():
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second.

    For ``x = [a, b]`` split at ``x.shape[-1] // 2`` this returns
    ``[-b, a]``.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved pairs of the last dimension.

    Each pair ``(x[2i], x[2i+1])`` becomes ``(-x[2i+1], x[2i])``.
    """
    even = x[..., ::2]
    odd = x[..., 1::2]
    paired = torch.stack((-odd, even), dim=-1)
    return paired.flatten(-2)
def _apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """Apply a rotary embedding to ``x`` using plain PyTorch ops.

    ``cos`` and ``sin`` get a broadcast axis inserted before their last
    dimension and are cast to ``x.dtype``. In Neox style the two halves of
    the last dimension of ``x`` form the rotated pairs; otherwise the
    even/odd interleaved elements do.
    """
    cos_b = cos.unsqueeze(-2).to(x.dtype)
    sin_b = sin.unsqueeze(-2).to(x.dtype)

    if is_neox_style:
        first, second = torch.chunk(x, 2, dim=-1)
        out_first = first * cos_b - second * sin_b
        out_second = second * cos_b + first * sin_b
        return torch.cat((out_first, out_second), dim=-1)

    even = x[..., ::2]
    odd = x[..., 1::2]
    out_even = even * cos_b - odd * sin_b
    out_odd = odd * cos_b + even * sin_b
    # Re-interleave the rotated pairs.
    return torch.stack((out_even, out_odd), dim=-1).flatten(-2)
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                      is_neox_style: bool) -> torch.Tensor:
    """Apply a rotary embedding, using the flash-attn kernel on CUDA.

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    if not current_platform.is_cuda():
        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)

    # The kernel is called with a leading axis of size one and with the
    # negated style flag; the extra axis is dropped from the result.
    batched = x.unsqueeze(0)
    rotated = apply_rotary_emb(batched, cos, sin, not is_neox_style)
    return rotated.squeeze(0)
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding.

    A cos/sin cache is computed once in the constructor and registered as
    a non-persistent buffer. The ``forward_*`` methods apply it to the
    query (and optionally the key) tensor for a given backend.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        # Subclasses override _compute_cos_sin_cache(), so any attribute
        # they read there must already be set when this line runs.
        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        # NOTE(review): torch.arange below is called without a device
        # argument, so where the cache is created depends on the default
        # device -- confirm the note above is still accurate.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        # Outer product of positions and inverse frequencies.
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        # Cos values fill the first half of the last dim, sin the second.
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """A PyTorch-native implementation of forward()."""
        if offsets is not None:
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        # Gather one cache row per token and split it into cos and sin.
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        # Only the first rotary_dim channels are rotated; the rest pass
        # through unchanged.
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = _apply_rotary_emb_torch(query_rot, cos, sin,
                                            self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        # key may be None in some cases, e.g. cross-layer KV sharing
        if key is not None:
            key_shape = key.shape
            key = key.view(num_tokens, -1, self.head_size)
            key_rot = key[..., :self.rotary_dim]
            key_pass = key[..., self.rotary_dim:]
            key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
                                              self.is_neox_style)
            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the embedding through vLLM's custom ops.

        The ops update ``query`` and ``key`` in place; the same tensors
        are returned.
        """
        from vllm import _custom_ops as ops

        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
        # is expensive, so avoid calling it if possible
        if self.cos_sin_cache.device != query.device or \
            self.cos_sin_cache.dtype != query.dtype:
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)

        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                         self.cos_sin_cache,
                                         self.is_neox_style, self.rotary_dim,
                                         offsets)
        else:
            ops.rotary_embedding(positions, query, key, self.head_size,
                                 self.cos_sin_cache, self.is_neox_style)
        return query, key

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the embedding through the IPEX ops, falling back to
        forward_native() when ``key`` is None."""
        from vllm._ipex_ops import ipex_ops as ops

        # Unlike forward_cuda(), the cache is reassigned on every call.
        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                   dtype=query.dtype)
        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if key is None:
            # XPU kernel doesn't support key=None so fall back to native impl
            # TODO(sarckk): add support for optional key in
            # ipex.llm.functional.rotary_embedding_batched
            return self.forward_native(positions, query, key, offsets)
        else:
            if offsets is not None:
                ops.batched_rotary_embedding(positions, query, key,
                                             self.head_size,
                                             self.cos_sin_cache,
                                             self.is_neox_style,
                                             self.rotary_dim, offsets)
            else:
                ops.rotary_embedding(positions, query, key, self.head_size,
                                     self.cos_sin_cache, self.is_neox_style)
        return query, key

    def forward_neuron(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the embedding using view-based slicing instead of
        ellipsis indexing for the partial-rotary case."""

        def _apply_rotary_emb_neuron(
            x: torch.Tensor,
            cos: torch.Tensor,
            sin: torch.Tensor,
            is_neox_style: bool,
        ) -> torch.Tensor:
            # Same arithmetic as _apply_rotary_emb_torch; only the way
            # the even/odd elements are extracted differs.
            cos = cos.unsqueeze(-2).to(x.dtype)
            sin = sin.unsqueeze(-2).to(x.dtype)
            if is_neox_style:
                x1, x2 = torch.chunk(x, 2, dim=-1)
            else:
                # x1 = x[..., ::2]

                # x2 = x[..., 1::2]
                d = x.shape[-1] // 2
                x_reshaped = x.view(-1, x.shape[-1])
                x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d)
                x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d)
            o1 = x1 * cos - x2 * sin
            o2 = x2 * cos + x1 * sin
            if is_neox_style:
                return torch.cat((o1, o2), dim=-1)
            else:
                return torch.stack((o1, o2), dim=-1).flatten(-2)

        if offsets is not None:
            positions = positions + offsets

        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                   dtype=query.dtype)

        positions = positions.flatten()
        num_tokens = positions.shape[0]
        # Gather one cache row per token and split it into cos and sin.
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        if key is not None:
            key_shape = key.shape
            key = key.view(num_tokens, -1, self.head_size)

        if self.rotary_dim == self.head_size:
            # Full rotary: rotate every channel via _apply_rotary_emb().
            query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
            query = query.reshape(query_shape)
            if key is not None:
                key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
                key = key.reshape(key_shape)
        else:
            # Partial rotary: split each head into the rotated prefix and
            # the untouched remainder using 2-D views.
            head_size = query.shape[-1]
            query_reshaped = query.view(-1, head_size)
            query_pass = query_reshaped[:, self.rotary_dim:].view(
                *query.shape[:-1], head_size - self.rotary_dim)
            query_rot = query_reshaped[:, :self.rotary_dim].view(
                *query.shape[:-1], self.rotary_dim)
            query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin,
                                                 self.is_neox_style)
            query = torch.cat((query_rot, query_pass),
                              dim=-1).reshape(query_shape)

            if key is not None:
                key_reshaped = key.view(-1, head_size)
                key_pass = key_reshaped[:, self.rotary_dim:].view(
                    *key.shape[:-1], head_size - self.rotary_dim)
                key_rot = key_reshaped[:, :self.rotary_dim].view(
                    *key.shape[:-1], self.rotary_dim)
                key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
                                                   self.is_neox_style)
                key = torch.cat((key_rot, key_pass),
                                dim=-1).reshape(key_shape)
        return query, key

    def extra_repr(self) -> str:
        """Return a comma-separated summary of the configuration."""
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        return s
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    It supports multiple scaling factors. Since multiple LoRA adapters may
    have different scaling factors, we need multiple cos/sin caches. In this
    way, instead of running rotary embedding kernel per lora, we can run
    multiple lora in a batched way.

    One cos/sin cache is built per entry of ``scaling_factors`` and the
    caches are stacked along the position (row) axis, in the order the
    factors were given:

        [cache(factor_0);
         cache(factor_1);
         ...]

    We then use offsets to index into the stacked cache for the respective
    scaling factors. The row offset at which each factor's cache starts can
    be accessed via the `scaling_factor_to_offset` API.

    NOTE(review): this class does not add a cache for the default factor of
    1 on its own; callers that need it presumably include 1.0 in
    ``scaling_factors`` -- confirm against the call sites.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factors: Union[list[float], float],
        dtype: torch.dtype,
    ) -> None:
        # Accept a bare number (int as well as float, e.g. ``2`` read from a
        # JSON config) and normalise it to a one-element list.
        if isinstance(scaling_factors, (int, float)):
            scaling_factors = [scaling_factors]
        self.scaling_factors: list[float] = scaling_factors  # noqa
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style, dtype)
        # Lazy initialized.
        self._scaling_factor_to_offset: dict[float, int]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build and stack one cos/sin cache per scaling factor.

        Also records, in ``self._scaling_factor_to_offset``, the first row
        of each factor's cache inside the stacked tensor.
        """
        inv_freq = self._compute_inv_freq(self.base)
        cache_list: list[torch.Tensor] = []
        # offsets to the next cache in a tensor.
        # Each offset corresponds to the same index in scaling_factors.
        offsets: list[int] = []
        next_offset = 0
        for scaling_factor in self.scaling_factors:
            # NOTE(woosuk): self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * self.scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float)
            t = t / scaling_factor

            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)

            # This cache starts right after all previously stacked rows.
            offsets.append(next_offset)
            cache_list.append(cache)
            next_offset += cache.shape[0]

        assert len(self.scaling_factors) == len(offsets)
        self._scaling_factor_to_offset = {
            float(scaling_factor): offset
            for scaling_factor, offset in zip(self.scaling_factors, offsets)
        }
        return torch.cat(cache_list, dim=0)

    @property
    def scaling_factor_to_offset(self) -> dict[float, int]:
        """Mapping from scaling factor to its first row in the cache."""
        return self._scaling_factor_to_offset
class NTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with fixed and mixed NTK scaling.
    https://kexue.fm/archives/9706 """

    def __init__(self,
                 head_size: int,
                 rotary_dim: int,
                 max_position_embeddings: int,
                 base: float,
                 is_neox_style: bool,
                 scaling_factor: float,
                 dtype: torch.dtype,
                 mixed_b: Optional[float] = None) -> None:
        # Both values are read by _compute_inv_freq(), which the parent
        # constructor triggers, so they must be stored first.
        self.scaling_factor = scaling_factor
        self.mixed_b = mixed_b
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style, dtype)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return the NTK-scaled inverse frequencies.

        Note that the ``base`` argument is not used: the computation
        always starts from ``self.base``.
        """
        if self.mixed_b is None:
            # Fixed NTK: scale the base, then divide by a constant.
            scaled = super()._compute_inv_freq(self.base *
                                               self.scaling_factor)
            return scaled / self.scaling_factor**(2 / self.rotary_dim)

        # Mixed NTK: keep the base and divide by a per-index factor.
        plain = super()._compute_inv_freq(self.base)
        a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim /
                                                       2)**self.mixed_b
        indices = torch.arange(1, self.rotary_dim // 2 + 1).float()
        lambda_1_m = (a * indices**self.mixed_b).exp()
        return plain / lambda_1_m
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
    ) -> None:
        # Must be stored before the parent constructor runs, because the
        # parent builds the cache through _compute_cos_sin_cache().
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style, dtype)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin cache for the scaled maximum length."""
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        scaled_len = self.max_position_embeddings * self.scaling_factor

        length_ratio = (self.scaling_factor * scaled_len /
                        self.max_position_embeddings)
        growth = length_ratio - (self.scaling_factor - 1)
        exponent = self.rotary_dim / (self.rotary_dim - 2)
        ntk_base = self.base * growth**exponent
        inv_freq = self._compute_inv_freq(ntk_base)

        positions = torch.arange(scaled_len, dtype=torch.float)
        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        # Cos values fill the first half of the last dim, sin the second.
        return torch.cat((angles.cos(), angles.sin()), dim=-1)
class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK alpha.

    Based on the original RotaryEmbedding implementation.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_alpha: float,
        dtype: torch.dtype,
    ) -> None:
        # Must be stored before the parent constructor runs, because the
        # parent builds the cache through _compute_cos_sin_cache().
        self.scaling_alpha = scaling_alpha
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style, dtype)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin cache using the alpha-scaled rotary base."""
        # For Hunyuan DynamicNTKAlphaRotaryEmbedding
        exponent = self.rotary_dim / (self.rotary_dim - 2)
        ntk_base = self.base * self.scaling_alpha**exponent
        inv_freq = self._compute_inv_freq(ntk_base)

        positions = torch.arange(self.max_position_embeddings,
                                 dtype=torch.float)
        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        # Cos values fill the first half of the last dim, sin the second.
        return torch.cat((angles.cos(), angles.sin()), dim=-1)
# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations: int,
                              dim: int,
                              base: float = 10000,
                              max_position_embeddings: int = 2048) -> float:
    """Return the (fractional) dimension index for ``num_rotations``.

    Computes ``dim * ln(max_position_embeddings / (2*pi*num_rotations))``
    divided by ``2 * ln(base)``.
    """
    full_turns = num_rotations * 2 * math.pi
    numerator = dim * math.log(max_position_embeddings / full_turns)
    denominator = 2 * math.log(base)
    return numerator / denominator
# Find dim range bounds based on rotations
def _yarn_find_correction_range(
        low_rot: int,
        high_rot: int,
        dim: int,
        base: float = 10000,
        max_position_embeddings: int = 2048) -> tuple[int, int]:
    """Return the ``(low, high)`` dimension bounds for the given rotation
    counts, limited to the range ``[0, dim - 1]``."""
    lower_raw = _yarn_find_correction_dim(low_rot, dim, base,
                                          max_position_embeddings)
    upper_raw = _yarn_find_correction_dim(high_rot, dim, base,
                                          max_position_embeddings)
    lower = math.floor(lower_raw)
    upper = math.ceil(upper_raw)
    # Clamp values just in case
    return max(lower, 0), min(upper, dim - 1)
def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
                           dtype: torch.dtype) -> torch.Tensor:
    """Return a length-``dim`` ramp that rises linearly from 0 at index
    ``low`` to 1 at index ``high`` and is clamped to ``[0, 1]``."""
    # Prevent singularity: widen a zero-width ramp slightly.
    upper = high + 0.001 if low == high else high

    indices = torch.arange(dim, dtype=dtype)
    ramp = (indices - low) / (upper - low)
    return torch.clamp(ramp, 0, 1)
def _yarn_get_mscale(scale: float = 1) -> float:
    """Return ``0.1 * ln(scale) + 1.0`` for ``scale > 1``, else ``1.0``."""
    return 0.1 * math.log(scale) + 1.0 if scale > 1 else 1.0
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        # All of these are read while the parent constructor builds the
        # cache, so they must be stored before calling it.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation
        magnitude = _yarn_get_mscale(self.scaling_factor) * attn_factor
        self.mscale = float(magnitude)
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Blend interpolated and extrapolated inverse frequencies.

        Here the argument is the scaling factor, not a rotary base; the
        base is always taken from ``self.base``.
        """
        exponents = torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
        pos_freqs = self.base**exponents
        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast,
                                                self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        ramp = _yarn_linear_ramp_mask(low,
                                      high,
                                      self.rotary_dim // 2,
                                      dtype=torch.float)
        extrapolation_weight = (1 - ramp) * self.extrapolation_factor
        return (interpolated * (1 - extrapolation_weight) +
                extrapolated * extrapolation_weight)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin cache, with both halves multiplied by
        ``self.mscale``."""
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        positions = torch.arange(self.max_position_embeddings *
                                 self.scaling_factor,
                                 dtype=torch.float32)
        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        scaled_cos = angles.cos() * self.mscale
        scaled_sin = angles.sin() * self.mscale
        return torch.cat((scaled_cos, scaled_sin), dim=-1)
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
    """Phi3 family of models scaled rotary embedding.

    Based on the original RotaryEmbedding implementation.

    Two cos/sin tables are built, one from ``short_factor`` covering
    ``original_max_position_embeddings`` rows and one from ``long_factor``
    covering ``max_position_embeddings`` rows. They are concatenated
    (short first) into the non-persistent buffer
    ``long_short_cos_sin_cache``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        short_factor: list[float],
        long_factor: list[float],
        short_mscale: Optional[float] = None,
        long_mscale: Optional[float] = None,
    ):
        """Build and register the concatenated short/long cos/sin table.

        Args:
            head_size: Size of one attention head.
            rotary_dim: Number of leading head dimensions that are rotated.
            max_position_embeddings: Row count of the long table.
            original_max_position_embeddings: Row count of the short table.
            base: Base of the rotary frequencies.
            is_neox_style: Must not be ``False``.
            dtype: dtype the tables are cast to.
            short_factor: Per-frequency rescale factors for the short table.
            long_factor: Per-frequency rescale factors for the long table.
            short_mscale: Magnitude scale for the short table; derived from
                the length ratio when ``None``.
            long_mscale: Magnitude scale for the long table; derived from
                the length ratio when ``None``.

        Raises:
            ValueError: If ``is_neox_style`` is ``False``.
        """
        super().__init__()

        if is_neox_style is False:
            raise ValueError(
                "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
            )

        self.rotary_dim = rotary_dim
        self.head_size = head_size
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.short_factor = short_factor
        self.long_factor = long_factor

        scale = self.max_position_embeddings / \
                self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(
                1 + math.log(scale) /
                math.log(self.original_max_position_embeddings))
        # Explicit mscale values win; otherwise both tables share the
        # factor derived from the context-length ratio.
        if short_mscale is None:
            short_mscale = scaling_factor
        if long_mscale is None:
            long_mscale = scaling_factor

        self.short_mscale = short_mscale
        self.long_mscale = long_mscale

        short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale)
        short_cache = short_cache.to(dtype)

        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                 long_factor, long_mscale)
        long_cache = long_cache.to(dtype)

        # Rows [0, original_max) hold the short table; the long table
        # follows immediately after.
        long_short_cache = torch.cat([short_cache, long_cache], dim=0)
        self.register_buffer("long_short_cos_sin_cache",
                             long_short_cache,
                             persistent=False)

    def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor:
        """Return inverse frequencies divided by ``rescale_factors``."""
        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
        inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)))
        return inv_freq

    def _compute_cos_sin_cache(
        self,
        max_position_embeddings: int,
        rescale_factors: list[float],
        mscale: float,
    ) -> torch.Tensor:
        """Return a ``max_position_embeddings``-row table of scaled cos|sin."""
        inv_freq = self._compute_inv_freq(rescale_factors)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Rotate ``query`` and ``key`` using the short or long table.

        Args:
            positions: Integer position indices; passed to
                ``torch.index_select`` so presumably 1-D - TODO confirm.
            query: Query tensor whose last dimension is a multiple of
                ``head_size``.
            key: Key tensor with the same layout; must not be ``None``.
            offsets: Optional additional index offsets added to positions.

        Returns:
            The rotated ``(query, key)`` with heads flattened back into the
            last dimension.
        """
        assert key is not None
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        k = self.original_max_position_embeddings
        # Use `>=`, not `>`: positions index table rows and the short table
        # only has rows 0..k-1. With `>` a batch whose largest position is
        # exactly k would read row k of the concatenated table, which is
        # row 0 of the long table (i.e. the rotation for position 0).
        # The switch is batch-wide: once any position reaches k, every
        # position is shifted into the long table.
        long_prompt_offset = (torch.any(positions >= k).float() *
                              torch.full_like(positions, k)).long()
        idx = torch.add(positions, long_prompt_offset)
        idx = torch.add(idx, offsets) if offsets is not None else idx
        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

        cos, sin = cos_sin.chunk(2, dim=-1)
        # Duplicate the half-width tables across both halves of the rotary
        # dims and add a broadcast axis for the heads.
        cos = cos.repeat(1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 2).unsqueeze(-2)

        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
        query = torch.cat((query_rot, query_pass), dim=-1)

        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
        key = torch.cat((key_rot, key_pass), dim=-1)

        return query.flatten(-2), key.flatten(-2)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the magnitude correction for a context scaling factor.

    Args:
        scale: Context-length scaling factor.
        mscale: Multiplier applied to the logarithmic term.

    Returns:
        ``1.0`` when ``scale <= 1``; otherwise
        ``0.1 * mscale * ln(scale) + 1.0``.
    """
    if scale <= 1:
        # No extension requested: leave magnitudes untouched.
        return 1.0
    log_scale = math.log(scale)
    return 0.1 * mscale * log_scale + 1.0
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Differs from the plain YaRN variant in how ``mscale`` is derived (a
    ratio of two ``yarn_get_mscale`` values) and in creating the frequency
    and position ranges on ``current_platform.device_type``.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        # Set before the parent constructor runs because the `_compute_*`
        # overrides below read these attributes.
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.attn_factor = attn_factor
        self.extrapolation_factor = extrapolation_factor
        self.scaling_factor = scaling_factor
        # Get n-d magnitude scaling corrected for interpolation.
        numerator = yarn_get_mscale(self.scaling_factor, float(mscale))
        denominator = yarn_get_mscale(self.scaling_factor,
                                      float(mscale_all_dim))
        self.mscale = float(numerator / denominator * attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Return inverse frequencies blended between the two regimes."""
        exponents = torch.arange(
            0,
            self.rotary_dim,
            2,
            dtype=torch.float,
            device=current_platform.device_type) / self.rotary_dim
        pos_freqs = self.base**exponents
        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(
            self.beta_fast, self.beta_slow, self.rotary_dim, self.base,
            self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        ramp = _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2, dtype=torch.float)
        blend = (1 - ramp) * self.extrapolation_factor
        return interpolated * (1 - blend) + extrapolated * blend

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin table (cos first), scaled by ``self.mscale``."""
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        position_count = self.max_position_embeddings * self.scaling_factor
        positions = torch.arange(position_count,
                                 device=current_platform.device_type,
                                 dtype=torch.float32)
        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        scaled_cos = angles.cos() * self.mscale
        scaled_sin = angles.sin() * self.mscale
        return torch.cat((scaled_cos, scaled_sin), dim=-1)

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the rotary embedding to ``query`` and ``key`` in PyTorch.

        Only the first ``rotary_dim`` entries of the last dimension are
        rotated; any remainder is passed through unchanged. ``key`` must
        not be ``None``.
        """
        assert key is not None
        has_passthrough = self.rotary_dim < self.head_size
        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if has_passthrough:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]

        # Move the table next to the positions once, on first mismatch.
        if self.cos_sin_cache.device != positions.device:
            self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
        lookup = (positions
                  if offsets is None else torch.add(positions, offsets))
        cos, sin = self.cos_sin_cache[lookup].chunk(2, dim=-1)

        if self.is_neox_style:
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            rotate_fn = _rotate_neox
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            rotate_fn = _rotate_gptj
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if not has_passthrough:
            return query_rot, key_rot
        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)
        return query, key
class Llama3RotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding whose inverse frequencies are rescaled per wavelength.

    Frequencies with short wavelengths are kept, those with long wavelengths
    are divided by ``scaling_factor``, and the band in between is blended.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ) -> None:
        # `_compute_inv_freq` reads these, so store them before the parent
        # constructor runs.
        self.orig_max_position = orig_max_position
        self.high_freq_factor = high_freq_factor
        self.low_freq_factor = low_freq_factor
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return the parent's inverse frequencies after band-wise scaling."""
        inv_freqs = super()._compute_inv_freq(base)
        wave_len = 2 * math.pi / inv_freqs
        high_freq_wavelen = self.orig_max_position / self.high_freq_factor
        low_freq_wavelen = self.orig_max_position / self.low_freq_factor

        if self.low_freq_factor == self.high_freq_factor:
            # Degenerate band: avoid dividing by zero below.
            smooth = 0
        else:
            smooth = (self.orig_max_position / wave_len -
                      self.low_freq_factor) / (self.high_freq_factor -
                                               self.low_freq_factor)

        fully_scaled = inv_freqs / self.scaling_factor
        blended = ((1 - smooth) * inv_freqs / self.scaling_factor +
                   smooth * inv_freqs)
        mid_or_low = torch.where(wave_len > low_freq_wavelen, fully_scaled,
                                 blended)
        return torch.where(wave_len < high_freq_wavelen, inv_freqs,
                           mid_or_low)
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding over image patch indices, stored as complex values.

    The cache is indexed by patch id (plus one trailing row marked with a
    negative id whose angles are zeroed) rather than by token position.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ):
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return the first ``rotary_dim // 2`` parent inverse frequencies."""
        half = self.rotary_dim // 2
        return super()._compute_inv_freq(base)[:half]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the complex rotation table from x/y patch coordinates."""
        inv_freq = self._compute_inv_freq(self.base)

        # self.max_position_embeddings here is number of image patches
        # i.e. (image_size // patch_size) ** 2
        patch_count = self.max_position_embeddings
        patch_ids = torch.arange(patch_count,
                                 dtype=torch.int32).reshape(patch_count, 1)
        # Append one extra row and overwrite it with the negative marker.
        patch_ids = torch.cat([patch_ids, patch_ids[:1]], dim=0)
        patch_ids[-1, -1] = -2  # set to ID_CLS_TOKEN
        side = int(math.sqrt(patch_count))

        def _axis_angles(coords: torch.Tensor) -> torch.Tensor:
            # 1-based coordinate times each frequency, every value doubled
            # along the last axis.
            scaled = (coords + 1)[..., None] * inv_freq[None, None, :]
            return scaled.repeat_interleave(2, dim=-1)

        angles_x = _axis_angles(patch_ids % side)
        angles_y = _axis_angles(patch_ids // side)
        angles = torch.cat([angles_x, angles_y],
                           dim=-1).float().contiguous()[..., ::2]
        # Rows with a negative id get zero angles.
        angles = angles.masked_fill(patch_ids.reshape(-1, 1, 1) < 0, 0)
        cos_sin_pairs = torch.stack(
            [torch.cos(angles), torch.sin(angles)], dim=-1)
        return torch.view_as_complex(cos_sin_pairs)

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Rotate ``query`` and ``key`` by complex multiplication.

        Both inputs are viewed as complex numbers over consecutive pairs of
        the last dimension, multiplied by the cache broadcast over every
        axis except axis 1 and the last, and returned in their input dtype.
        ``key`` must not be ``None``.
        """
        assert key is not None
        self.cos_sin_cache = self.cos_sin_cache.to(query.device)

        query_pairs = query.float().reshape(*query.shape[:-1], -1, 2)
        key_pairs = key.float().reshape(*key.shape[:-1], -1, 2)
        query_ = torch.view_as_complex(query_pairs)
        key_ = torch.view_as_complex(key_pairs)

        # Keep the sizes of axis 1 and the last axis; broadcast the rest.
        last_axis = query_.ndim - 1
        broadcast_shape = [
            size if axis in (1, last_axis) else 1
            for axis, size in enumerate(query_.shape)
        ]
        freqs_ci = self.cos_sin_cache.view(*broadcast_shape)

        query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
        key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
        return query_out.type_as(query), key_out.type_as(key)
class
MRotaryEmbedding
(
RotaryEmbedding
):
class
MRotaryEmbedding
(
RotaryEmbedding
):
...
@@ -1024,14 +75,16 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1024,14 +75,16 @@ class MRotaryEmbedding(RotaryEmbedding):
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
_apply_rotary_emb
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query_rot
=
apply_rotary_emb_dispatch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key_shape
=
key
.
shape
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
_apply_rotary_emb
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key_rot
=
apply_rotary_emb_dispatch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
return
query
,
key
...
@@ -1615,353 +668,3 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1615,353 +668,3 @@ class MRotaryEmbedding(RotaryEmbedding):
updates
.
extend
([
audio_end_token_id
])
updates
.
extend
([
audio_end_token_id
])
return
updates
return
updates
@CustomOp.register("dual_chunk_rotary_embedding")
class DualChunkRotaryEmbedding(CustomOp):
    """Rotary positional embedding for Dual Chunk Attention.

    Five cos/sin tables are precomputed and registered as non-persistent
    buffers: one for keys and four for the query variants returned by
    ``forward`` (intra-chunk, successive-chunk clamped, successive-chunk
    unclamped, and inter-chunk).
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        chunk_size: int,
        local_size: int,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.chunk_size = chunk_size
        self.local_size = local_size
        self.dtype = dtype
        # NOTE(review): the tables are always placed on the current CUDA
        # device, so constructing this module requires CUDA.
        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
        (q_cache, qc_cache, k_cache, qc_no_clamp_cache,
         q_inter_cache) = self._compute_cos_sin_cache()

        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
        self.register_buffer("cos_sin_qc_no_clamp_cache",
                             qc_no_clamp_cache,
                             persistent=False)
        self.register_buffer("cos_sin_q_inter_cache",
                             q_inter_cache,
                             persistent=False)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _build_cos_sin_table(self, timesteps: torch.Tensor,
                             inv_freq: torch.Tensor) -> torch.Tensor:
        """Return cos|sin of ``outer(timesteps, inv_freq)`` on the target
        dtype and device."""
        freqs = torch.outer(timesteps, inv_freq)
        table = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
        return table.to(dtype=self.dtype, device=self.device)

    def _compute_cos_sin_cache(
        self
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Compute the cos and sin cache.

        Returns:
            ``(q_cache, qc_cache, k_cache, qc_no_clamp_cache,
            q_inter_cache)``; each tensor stores cos values followed by sin
            values along the last dimension.
        """
        inv_freq = self._compute_inv_freq(self.base)
        chunk_len = self.chunk_size - self.local_size
        q_t = torch.arange(chunk_len, dtype=torch.float)
        qc_t = (torch.arange(chunk_len, dtype=torch.float) +
                chunk_len).clamp(max=self.chunk_size)
        k_t = torch.arange(self.max_position_embeddings,
                           dtype=torch.float) % chunk_len

        # count from chunk_len, no clamp(self.chunk_size) restriction
        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
        # count from self.chunk_size for q_inter's rope
        q_inter_t = torch.arange(chunk_len,
                                 dtype=torch.float) + self.chunk_size

        q_cache = self._build_cos_sin_table(q_t, inv_freq)
        qc_cache = self._build_cos_sin_table(qc_t, inv_freq)
        k_cache = self._build_cos_sin_table(k_t, inv_freq)
        qc_no_clamp_cache = self._build_cos_sin_table(qc_no_clamp_t, inv_freq)
        q_inter_cache = self._build_cos_sin_table(q_inter_t, inv_freq)
        return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``key`` once and ``query`` in five variants.

        Returns:
            ``(query, key)`` where ``query`` is the concatenation along the
            last dimension of the five rotated query variants, in the order
            intra, successive, inter, successive-critical, inter-critical.
        """
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]
        else:
            query_pass = None
            key_pass = None

        positions_with_offsets = (torch.add(positions, offsets)
                                  if offsets is not None else positions)
        key = self._apply_rotary_embedding(
            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass)
        chunk_len = self.chunk_size - self.local_size
        # Position inside the chunk; shared by four of the query lookups.
        in_chunk = positions_with_offsets % chunk_len
        query = self._apply_rotary_embedding(self.cos_sin_q_cache[in_chunk],
                                             query_rot, query_pass)
        query_succ = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[in_chunk], query_rot, query_pass)
        # The inter-chunk variant uses one fixed row (the last one of the
        # clamped table), repeated once per entry of positions' first axis.
        query_inter = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
            query_rot, query_pass)
        query_succ_critical = self._apply_rotary_embedding(
            self.cos_sin_qc_no_clamp_cache[in_chunk], query_rot, query_pass)
        query_inter_critical = self._apply_rotary_embedding(
            self.cos_sin_q_inter_cache[in_chunk], query_rot, query_pass)

        # merge query into one tensor to simplify the interfaces
        query = torch.cat((
            query,
            query_succ,
            query_inter,
            query_succ_critical,
            query_inter_critical,
        ),
                          dim=-1)
        return query, key

    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
        """Rotate ``hidden_rot`` with ``cos_sin`` and re-attach
        ``hidden_pass`` (used only when ``rotary_dim < head_size``)."""
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin

        if self.rotary_dim < self.head_size:
            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
        else:
            hidden = hidden_rot
        # Merge heads back into the last dim and drop a leading axis of
        # size one, if present.
        return hidden.flatten(-2).squeeze(0)

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module repr."""
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
        return s
# Module-level cache of constructed rotary embedding modules. `get_rope`
# looks up and stores entries here under a tuple built from its arguments.
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    """Return a (cached) rotary embedding module for the given configuration.

    Args:
        head_size: Size of one attention head.
        rotary_dim: Rotary dimension before ``partial_rotary_factor`` is
            applied.
        max_position: Maximum position passed to the embedding class.
        base: Base of the rotary frequencies.
        is_neox_style: Forwarded to the embedding class.
        rope_scaling: Optional scaling config; its ``"rope_type"`` entry
            selects the embedding class.
        dtype: Cache dtype; defaults to ``torch.get_default_dtype()``.
        partial_rotary_factor: When below 1.0, ``rotary_dim`` is multiplied
            by it and truncated to an int.
        dual_chunk_attention_config: When given, a
            ``DualChunkRotaryEmbedding`` is built regardless of
            ``rope_scaling``.

    Returns:
        The module stored in ``_ROPE_DICT`` for this configuration, created
        on first use.

    Raises:
        ValueError: For an unknown ``rope_type`` or a ``"dynamic"`` config
            lacking both ``"alpha"`` and ``"factor"``.
    """

    def _freeze(config, excluded=()):
        # Transforms every value that is a list into a tuple so the config
        # can take part in the (hashable) cache key.
        return tuple(
            (name, tuple(value) if isinstance(value, list) else value)
            for name, value in config.items() if name not in excluded)

    def _select(config, names):
        # Keep only the optional keyword arguments that are present.
        return {
            name: value
            for name, value in config.items() if name in names
        }

    def _construct() -> RotaryEmbedding:
        if dual_chunk_attention_config is not None:
            optional_kwargs = _select(dual_chunk_attention_config,
                                      ("chunk_size", "local_size"))
            return DualChunkRotaryEmbedding(head_size, rotary_dim,
                                            max_position, base, is_neox_style,
                                            dtype, **optional_kwargs)
        if not rope_scaling:
            return RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                   is_neox_style, dtype)

        rope_type = rope_scaling["rope_type"]

        if rope_type == "llama3":
            scaling_factor = rope_scaling["factor"]
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            return Llama3RotaryEmbedding(head_size, rotary_dim, max_position,
                                         base, is_neox_style, dtype,
                                         scaling_factor, low_freq_factor,
                                         high_freq_factor,
                                         original_max_position)
        if rope_type == "mllama4":
            return Llama4VisionRotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
                                               is_neox_style, dtype)
        if rope_type == "default":
            if "mrope_section" in rope_scaling:
                return MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                    mrope_section=rope_scaling["mrope_section"],
                )
            return RotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                dtype,
            )
        if rope_type == "linear":
            scaling_factor = rope_scaling["factor"]
            return LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                max_position, base,
                                                is_neox_style, scaling_factor,
                                                dtype)
        if rope_type == "ntk":
            scaling_factor = rope_scaling["factor"]
            mixed_b = rope_scaling.get('mixed_b', None)
            return NTKScalingRotaryEmbedding(head_size, rotary_dim,
                                             max_position, base,
                                             is_neox_style, scaling_factor,
                                             dtype, mixed_b)
        if rope_type == "dynamic":
            if "alpha" in rope_scaling:
                scaling_alpha = rope_scaling["alpha"]
                return DynamicNTKAlphaRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_alpha, dtype)
            if "factor" in rope_scaling:
                scaling_factor = rope_scaling["factor"]
                return DynamicNTKScalingRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_factor, dtype)
            raise ValueError("Dynamic rope scaling must contain either "
                             "'alpha' or 'factor' field")
        if rope_type == "yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            optional_kwargs = _select(
                rope_scaling, ("extrapolation_factor", "attn_factor",
                               "beta_fast", "beta_slow"))
            # The original (pre-scaling) length is what the class receives.
            return YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                              original_max_position, base,
                                              is_neox_style, scaling_factor,
                                              dtype, **optional_kwargs)
        if rope_type == "deepseek_yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            optional_kwargs = _select(
                rope_scaling,
                ("extrapolation_factor", "attn_factor", "beta_fast",
                 "beta_slow", "mscale", "mscale_all_dim"))
            return DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **optional_kwargs)
        if rope_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            optional_kwargs = _select(rope_scaling,
                                      ("short_mscale", "long_mscale"))
            return Phi3LongRoPEScaledRotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **optional_kwargs)
        raise ValueError(f"Unknown RoPE scaling type {rope_type}")

    if dtype is None:
        dtype = torch.get_default_dtype()

    rope_scaling_args = (None
                         if rope_scaling is None else _freeze(rope_scaling))
    dual_chunk_attention_args = (
        None if dual_chunk_attention_config is None else _freeze(
            dual_chunk_attention_config,
            excluded=("sparse_attention_config", )))

    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)

    cache_key = (head_size, rotary_dim, max_position, base, is_neox_style,
                 rope_scaling_args, dual_chunk_attention_args, dtype)
    if cache_key not in _ROPE_DICT:
        _ROPE_DICT[cache_key] = _construct()
    return _ROPE_DICT[cache_key]
vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py
0 → 100644
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
from
.base
import
RotaryEmbedding
class NTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with fixed and mixed NTK scaling.

    With ``mixed_b`` unset the base is multiplied by ``scaling_factor``
    (fixed NTK); otherwise each frequency is divided by its own factor
    derived from ``mixed_b`` (mixed NTK).

    https://kexue.fm/archives/9706
    """

    def __init__(self,
                 head_size: int,
                 rotary_dim: int,
                 max_position_embeddings: int,
                 base: float,
                 is_neox_style: bool,
                 scaling_factor: float,
                 dtype: torch.dtype,
                 mixed_b: Optional[float] = None) -> None:
        # `_compute_inv_freq` reads both values, so store them before the
        # parent constructor runs.
        self.mixed_b = mixed_b
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return NTK-scaled inverse frequencies.

        The ``base`` argument is not used; ``self.base`` is read instead.
        """
        if self.mixed_b is None:
            # Fixed NTK: enlarge the base, then rescale uniformly.
            scaled_base = self.base * self.scaling_factor
            inv_freq = super()._compute_inv_freq(scaled_base)
            return inv_freq / self.scaling_factor**(2 / self.rotary_dim)

        # Mixed NTK: keep the base and divide by a per-index factor.
        inv_freq = super()._compute_inv_freq(self.base)
        half_dim = self.rotary_dim // 2
        a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim /
                                                       2)**self.mixed_b
        index_powers = torch.arange(1, half_dim + 1).float()**self.mixed_b
        lambda_1_m = (a * index_powers).exp()
        return inv_freq / lambda_1_m
vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
0 → 100644
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
.common
import
rotate_neox
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
    """Phi3 family of models scaled rotary embedding.

    Based on the original RotaryEmbedding implementation.

    Two cos/sin tables are built, one from ``short_factor`` covering
    ``original_max_position_embeddings`` rows and one from ``long_factor``
    covering ``max_position_embeddings`` rows. They are concatenated
    (short first) into the non-persistent buffer
    ``long_short_cos_sin_cache``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        short_factor: list[float],
        long_factor: list[float],
        short_mscale: Optional[float] = None,
        long_mscale: Optional[float] = None,
    ):
        """Build and register the concatenated short/long cos/sin table.

        Args:
            head_size: Size of one attention head.
            rotary_dim: Number of leading head dimensions that are rotated.
            max_position_embeddings: Row count of the long table.
            original_max_position_embeddings: Row count of the short table.
            base: Base of the rotary frequencies.
            is_neox_style: Must not be ``False``.
            dtype: dtype the tables are cast to.
            short_factor: Per-frequency rescale factors for the short table.
            long_factor: Per-frequency rescale factors for the long table.
            short_mscale: Magnitude scale for the short table; derived from
                the length ratio when ``None``.
            long_mscale: Magnitude scale for the long table; derived from
                the length ratio when ``None``.

        Raises:
            ValueError: If ``is_neox_style`` is ``False``.
        """
        super().__init__()

        if is_neox_style is False:
            raise ValueError(
                "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
            )

        self.rotary_dim = rotary_dim
        self.head_size = head_size
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.short_factor = short_factor
        self.long_factor = long_factor

        scale = self.max_position_embeddings / \
                self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(
                1 + math.log(scale) /
                math.log(self.original_max_position_embeddings))
        # Explicit mscale values win; otherwise both tables share the
        # factor derived from the context-length ratio.
        if short_mscale is None:
            short_mscale = scaling_factor
        if long_mscale is None:
            long_mscale = scaling_factor

        self.short_mscale = short_mscale
        self.long_mscale = long_mscale

        short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale)
        short_cache = short_cache.to(dtype)

        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                 long_factor, long_mscale)
        long_cache = long_cache.to(dtype)

        # Rows [0, original_max) hold the short table; the long table
        # follows immediately after.
        long_short_cache = torch.cat([short_cache, long_cache], dim=0)
        self.register_buffer("long_short_cos_sin_cache",
                             long_short_cache,
                             persistent=False)

    def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor:
        """Return inverse frequencies divided by ``rescale_factors``."""
        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
        inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)))
        return inv_freq

    def _compute_cos_sin_cache(
        self,
        max_position_embeddings: int,
        rescale_factors: list[float],
        mscale: float,
    ) -> torch.Tensor:
        """Return a ``max_position_embeddings``-row table of scaled cos|sin."""
        inv_freq = self._compute_inv_freq(rescale_factors)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Rotate ``query`` and ``key`` using the short or long table.

        Args:
            positions: Integer position indices; passed to
                ``torch.index_select`` so presumably 1-D - TODO confirm.
            query: Query tensor whose last dimension is a multiple of
                ``head_size``.
            key: Key tensor with the same layout; must not be ``None``.
            offsets: Optional additional index offsets added to positions.

        Returns:
            The rotated ``(query, key)`` with heads flattened back into the
            last dimension.
        """
        assert key is not None
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        k = self.original_max_position_embeddings
        # Use `>=`, not `>`: positions index table rows and the short table
        # only has rows 0..k-1. With `>` a batch whose largest position is
        # exactly k would read row k of the concatenated table, which is
        # row 0 of the long table (i.e. the rotation for position 0).
        # The switch is batch-wide: once any position reaches k, every
        # position is shifted into the long table.
        long_prompt_offset = (torch.any(positions >= k).float() *
                              torch.full_like(positions, k)).long()
        idx = torch.add(positions, long_prompt_offset)
        idx = torch.add(idx, offsets) if offsets is not None else idx
        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

        cos, sin = cos_sin.chunk(2, dim=-1)
        # Duplicate the half-width tables across both halves of the rotary
        # dims and add a broadcast axis for the heads.
        cos = cos.repeat(1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 2).unsqueeze(-2)

        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = query_rot * cos + rotate_neox(query_rot) * sin
        query = torch.cat((query_rot, query_pass), dim=-1)

        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = key_rot * cos + rotate_neox(key_rot) * sin
        key = torch.cat((key_rot, key_pass), dim=-1)

        return query.flatten(-2), key.flatten(-2)
vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py
0 → 100644
View file @
6ad6b8e1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
.base
import
RotaryEmbedding
from
.common
import
(
yarn_find_correction_range
,
yarn_get_mscale
,
yarn_linear_ramp_mask
)
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Blends interpolated and extrapolated inverse frequencies using a linear
    ramp over the rotary dimensions, and scales the cos/sin table by a
    magnitude factor derived from the scaling factor.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        # The `_compute_*` overrides below read these attributes, so they are
        # set before the parent constructor runs (which presumably builds the
        # cache through those overrides).
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.attn_factor = attn_factor
        self.extrapolation_factor = extrapolation_factor
        self.scaling_factor = scaling_factor
        # Get n-d magnitude scaling corrected for interpolation
        magnitude = yarn_get_mscale(self.scaling_factor) * attn_factor
        self.mscale = float(magnitude)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Return inverse frequencies blended between the two regimes."""
        exponents = torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
        pos_freqs = self.base**exponents
        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                               self.rotary_dim, self.base,
                                               self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        ramp = yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2, dtype=torch.float)
        blend = (1 - ramp) * self.extrapolation_factor
        return interpolated * (1 - blend) + extrapolated * blend

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin table, scaled by ``self.mscale``.

        The table has one row per position in
        ``range(max_position_embeddings * scaling_factor)``; cos values come
        first along the last dimension, followed by sin values.
        """
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        position_count = self.max_position_embeddings * self.scaling_factor
        positions = torch.arange(position_count, dtype=torch.float32)
        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        scaled_cos = angles.cos() * self.mscale
        scaled_sin = angles.sin() * self.mscale
        return torch.cat((scaled_cos, scaled_sin), dim=-1)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment