norm / vllm · Commit 054072be (unverified)
Authored Nov 12, 2023 by Woosuk Kwon; committed by GitHub on Nov 12, 2023.

[Minor] Move RoPE selection logic to `get_rope` (#1633)
Parent: eb825c1e
Showing 2 changed files with 47 additions and 34 deletions (+47, -34).
vllm/model_executor/layers/attention.py: +3, -33
vllm/model_executor/layers/rotary_embedding.py: +44, -1
vllm/model_executor/layers/attention.py

@@ -10,9 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
 from vllm import attention_ops
 from vllm import cache_ops
 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.rotary_embedding import (
-    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
-    RotaryEmbedding, YaRNScalingRotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding import get_rope

 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -319,36 +317,8 @@ class PagedAttentionWithRoPE(PagedAttention):
                          scale,
                          num_kv_heads,
                          sliding_window=sliding_window)
-        if rope_scaling is None:
-            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
-                                              max_position, base,
-                                              is_neox_style)
-        else:
-            scaling_type = rope_scaling["type"]
-            scaling_factor = rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LinearScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "dynamic":
-                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "yarn":
-                original_max_position = rope_scaling[
-                    "original_max_position_embeddings"]
-                assert max_position == original_max_position * scaling_factor
-                extra_kwargs = {
-                    k: v
-                    for k, v in rope_scaling.items()
-                    if k in ("extrapolation_factor", "attn_factor",
-                             "beta_fast", "beta_slow")
-                }
-                self.rotary_emb = YaRNScalingRotaryEmbedding(
-                    head_size, rotary_dim, original_max_position, base,
-                    is_neox_style, scaling_factor, **extra_kwargs)
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
+                                   is_neox_style, rope_scaling)

     def forward(
         self,
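For readers skimming the hunk above: the commit only relocates the selection logic; it does not change which class handles which rope_scaling["type"]. A compact, illustrative sketch (not code from the commit) of the dispatch that the removed if/else encoded and that get_rope now owns:

# rope_scaling is None selects plain RotaryEmbedding; an unknown "type"
# raises ValueError, exactly as in the removed block.
ROPE_SCALING_TYPE_TO_CLASS = {
    "linear": "LinearScalingRotaryEmbedding",
    "dynamic": "DynamicNTKScalingRotaryEmbedding",
    "yarn": "YaRNScalingRotaryEmbedding",
}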
vllm/model_executor/layers/rotary_embedding.py
@@ -22,7 +22,7 @@
 # limitations under the License.
 """Rotary Positional Embeddings."""
 import math
-from typing import Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -271,3 +271,46 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -271,3 +271,46 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
sin
=
(
freqs
.
sin
()
*
self
.
mscale
)
sin
=
(
freqs
.
sin
()
*
self
.
mscale
)
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
return
cache
return
cache
def
get_rope
(
head_size
:
int
,
rotary_dim
:
int
,
max_position
:
int
,
base
:
int
,
is_neox_style
:
bool
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]],
)
->
RotaryEmbedding
:
if
rope_scaling
is
None
:
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
)
else
:
scaling_type
=
rope_scaling
[
"type"
]
scaling_factor
=
rope_scaling
[
"factor"
]
if
scaling_type
==
"linear"
:
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
)
elif
scaling_type
==
"dynamic"
:
rotary_emb
=
DynamicNTKScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
)
elif
scaling_type
==
"yarn"
:
original_max_position
=
rope_scaling
[
"original_max_position_embeddings"
]
assert
max_position
==
original_max_position
*
scaling_factor
extra_kwargs
=
{
k
:
v
for
k
,
v
in
rope_scaling
.
items
()
if
k
in
(
"extrapolation_factor"
,
"attn_factor"
,
"beta_fast"
,
"beta_slow"
)
}
rotary_emb
=
YaRNScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
original_max_position
,
base
,
is_neox_style
,
scaling_factor
,
**
extra_kwargs
)
else
:
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
return
rotary_emb
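The new factory is now the single entry point model code needs; a minimal usage sketch, with head size, context length, base, and scaling factor chosen purely for illustration (not taken from any particular model config):

from vllm.model_executor.layers.rotary_embedding import get_rope

# No scaling dict -> plain RotaryEmbedding.
rope = get_rope(head_size=128, rotary_dim=128, max_position=4096,
                base=10000, is_neox_style=True, rope_scaling=None)

# HF-style scaling config -> LinearScalingRotaryEmbedding.
rope_linear = get_rope(128, 128, 4096, 10000, True,
                       {"type": "linear", "factor": 2.0})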