Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c5df56f8
Unverified
Commit
c5df56f8
authored
Jul 18, 2024
by
Simon Mo
Committed by
GitHub
Jul 19, 2024
Browse files
Add support for a rope extension method (#6553)
parent
1689219e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
4 deletions
+48
-4
vllm/config.py
vllm/config.py
+12
-2
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+36
-2
No files found.
vllm/config.py
View file @
c5df56f8
...
...
@@ -151,6 +151,15 @@ class ModelConfig:
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
if
(
getattr
(
self
.
hf_config
,
"max_position_embeddings"
,
0
)
==
131072
and
getattr
(
self
.
hf_config
,
"rope_scaling"
,
None
)
is
None
):
# Note(simon): this is a special case for a model that doesn't
# supply rope_scaling. We should remove this once the model is
# updated.
self
.
hf_config
.
update
({
"rope_scaling"
:
{
"type"
:
"extended"
,
}})
if
(
not
self
.
disable_sliding_window
and
self
.
hf_text_config
.
model_type
==
"gemma2"
and
self
.
hf_text_config
.
sliding_window
is
not
None
):
...
...
@@ -1442,8 +1451,9 @@ def _get_and_verify_max_len(
rope_scaling
=
getattr
(
hf_config
,
"rope_scaling"
,
None
)
# The correct one should be "longrope", kept "su" here
# to be backward compatible
if
rope_scaling
is
not
None
and
rope_scaling
[
"type"
]
!=
"su"
\
and
rope_scaling
[
"type"
]
!=
"longrope"
:
if
rope_scaling
is
not
None
and
rope_scaling
[
"type"
]
not
in
{
"su"
,
"longrope"
,
"extended"
}:
if
disable_sliding_window
:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
c5df56f8
...
...
@@ -733,6 +733,36 @@ class GemmaRotaryEmbedding(RotaryEmbedding):
return
inv_freq
class ExtendedRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose inverse frequencies are rescaled by wavelength.

    The inverse frequencies produced by the parent class are post-processed
    by :meth:`apply_scaling`: short-wavelength components are kept as-is,
    long-wavelength components are divided by a scale factor, and the band
    in between is linearly interpolated between the two.
    """

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Return the parent's inverse frequencies after band-wise scaling."""
        inv_freqs = super()._compute_inv_freq(base)
        return self.apply_scaling(inv_freqs)

    def apply_scaling(
        self,
        freqs: torch.Tensor,
        *,
        scale_factor: float = 8,
        low_freq_factor: float = 1,
        high_freq_factor: float = 4,
        old_context_len: int = 8192,
    ) -> torch.Tensor:
        """Rescale ``freqs`` element-wise according to each wavelength.

        With ``wavelen = 2 * pi / freq``:

        * ``wavelen < old_context_len / high_freq_factor``: unchanged.
        * ``wavelen > old_context_len / low_freq_factor``: divided by
          ``scale_factor``.
        * otherwise: linear blend of the scaled and unscaled frequency.

        Args:
            freqs: Inverse frequencies to rescale.
            scale_factor: Divisor applied to long-wavelength components.
            low_freq_factor: Sets the long-wavelength threshold.
            high_freq_factor: Sets the short-wavelength threshold.
            old_context_len: Context length the thresholds are derived from.

        Returns:
            A new tensor with the same shape, dtype and device as ``freqs``.

        Raises:
            ValueError: If ``low_freq_factor == high_freq_factor``, which
                would make the interpolation weight undefined.
        """
        # Validate once up front instead of asserting per element; unlike
        # ``assert`` this check is not stripped under ``python -O``.
        if low_freq_factor == high_freq_factor:
            raise ValueError(
                "low_freq_factor and high_freq_factor must differ")

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        # Vectorized over the whole tensor: no per-element Python loop and
        # no round trip through a Python list.
        wavelen = 2 * math.pi / freqs
        smooth = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor)
        scaled = freqs / scale_factor
        blended = (1 - smooth) * freqs / scale_factor + smooth * freqs

        return torch.where(
            wavelen < high_freq_wavelen,
            freqs,
            torch.where(wavelen > low_freq_wavelen, scaled, blended),
        )
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
...
...
@@ -767,9 +797,13 @@ def get_rope(
scaling_type
=
rope_scaling
[
"type"
]
# The correct one should be "longrope" but keep "su" here
# for backward compatibility
if
scaling_type
!=
"su"
and
scaling_type
!=
"longrope"
:
if
scaling_type
not
in
{
"su"
,
"longrope"
,
"extended"
}
:
scaling_factor
=
rope_scaling
[
"factor"
]
if
scaling_type
==
"linear"
:
if
scaling_type
==
"extended"
:
rotary_emb
=
ExtendedRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
)
elif
scaling_type
==
"linear"
:
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment