Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9333fb8e
Unverified
Commit
9333fb8e
authored
Jun 17, 2024
by
Amit Garg
Committed by
GitHub
Jun 17, 2024
Browse files
[Model] Rename Phi3 rope scaling type (#5595)
parent
e2b85cf8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
8 deletions
+16
-8
vllm/config.py
vllm/config.py
+4
-1
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+12
-7
No files found.
vllm/config.py
View file @
9333fb8e
...
...
@@ -1287,7 +1287,10 @@ def _get_and_verify_max_len(
derived_max_model_len
=
default_max_len
rope_scaling
=
getattr
(
hf_config
,
"rope_scaling"
,
None
)
if
rope_scaling
is
not
None
and
rope_scaling
[
"type"
]
!=
"su"
:
# The correct one should be "longrope", kept "su" here
# to be backward compatible
if
rope_scaling
is
not
None
and
rope_scaling
[
"type"
]
!=
"su"
\
and
rope_scaling
[
"type"
]
!=
"longrope"
:
if
disable_sliding_window
:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
9333fb8e
...
...
@@ -467,7 +467,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
return
cache
class
Phi3
Su
ScaledRotaryEmbedding
(
nn
.
Module
):
class
Phi3
LongRoPE
ScaledRotaryEmbedding
(
nn
.
Module
):
"""Phi3 family of models scaled rotary embedding.
Based on the original RotaryEmbedding implementation.
...
...
@@ -491,11 +491,12 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
if
rotary_dim
!=
head_size
:
raise
ValueError
(
f
"`Phi3
Su
ScaledRotaryEmbedding` does not support
rotary_dim !=
\
head_size (
{
rotary_dim
}
!=
{
head_size
}
)."
)
f
"`Phi3
LongRoPE
ScaledRotaryEmbedding` does not support
\
rotary_dim !=
head_size (
{
rotary_dim
}
!=
{
head_size
}
)."
)
if
is_neox_style
is
False
:
raise
ValueError
(
"`Phi3SuScaledRotaryEmbedding` only supports neox_style."
)
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
)
self
.
head_size
=
head_size
self
.
max_position_embeddings
=
max_position_embeddings
...
...
@@ -608,7 +609,9 @@ def get_rope(
is_neox_style
,
dtype
)
else
:
scaling_type
=
rope_scaling
[
"type"
]
if
scaling_type
!=
"su"
:
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if
scaling_type
!=
"su"
and
scaling_type
!=
"longrope"
:
scaling_factor
=
rope_scaling
[
"factor"
]
if
scaling_type
==
"linear"
:
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
...
...
@@ -633,7 +636,9 @@ def get_rope(
base
,
is_neox_style
,
scaling_factor
,
dtype
,
**
extra_kwargs
)
elif
scaling_type
==
"su"
:
# The correct one should be "longrope" but keep "su" here
# for backward compatible
elif
scaling_type
==
"su"
or
scaling_type
==
"longrope"
:
short_factor
=
rope_scaling
[
"short_factor"
]
long_factor
=
rope_scaling
[
"long_factor"
]
original_max_position
=
rope_scaling
[
...
...
@@ -643,7 +648,7 @@ def get_rope(
for
k
,
v
in
rope_scaling
.
items
()
if
k
in
(
"short_mscale"
,
"long_mscale"
)
}
rotary_emb
=
Phi3
Su
ScaledRotaryEmbedding
(
rotary_emb
=
Phi3
LongRoPE
ScaledRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
base
,
is_neox_style
,
dtype
,
short_factor
,
long_factor
,
**
extra_kwargs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment