Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
200a2ffa
Unverified
Commit
200a2ffa
authored
Aug 18, 2024
by
Woosuk Kwon
Committed by
GitHub
Aug 18, 2024
Browse files
[Misc] Refactor Llama3 RoPE initialization (#7637)
parent
40e1360b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
29 deletions
+53
-29
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+53
-29
No files found.
vllm/model_executor/layers/rotary_embedding.py
View file @
200a2ffa
...
@@ -734,34 +734,50 @@ class GemmaRotaryEmbedding(RotaryEmbedding):
...
@@ -734,34 +734,50 @@ class GemmaRotaryEmbedding(RotaryEmbedding):
return
inv_freq
return
inv_freq
class
ExtendedRotaryEmbedding
(
RotaryEmbedding
):
class
Llama3RotaryEmbedding
(
RotaryEmbedding
):
def
__init__
(
self
,
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
scaling_factor
:
float
,
low_freq_factor
:
float
,
high_freq_factor
:
float
,
orig_max_position
:
int
,
)
->
None
:
self
.
scaling_factor
=
scaling_factor
self
.
low_freq_factor
=
low_freq_factor
self
.
high_freq_factor
=
high_freq_factor
self
.
orig_max_position
=
orig_max_position
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
,
dtype
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
inv_freqs
=
super
().
_compute_inv_freq
(
base
)
inv_freqs
=
super
().
_compute_inv_freq
(
base
)
return
self
.
apply_scaling
(
inv_freqs
)
low_freq_wavelen
=
self
.
orig_max_position
/
self
.
low_freq_factor
high_freq_wavelen
=
self
.
orig_max_position
/
self
.
high_freq_factor
def
apply_scaling
(
self
,
freqs
:
torch
.
Tensor
):
scale_factor
=
8
wave_len
=
2
*
math
.
pi
/
inv_freqs
low_freq_factor
=
1
if
self
.
low_freq_factor
!=
self
.
high_freq_factor
:
high_freq_factor
=
4
smooth
=
(
self
.
orig_max_position
/
wave_len
-
self
.
low_freq_factor
old_context_len
=
8192
)
/
(
self
.
high_freq_factor
-
self
.
low_freq_factor
)
low_freq_wavelen
=
old_context_len
/
low_freq_factor
high_freq_wavelen
=
old_context_len
/
high_freq_factor
new_freqs
=
[]
for
freq
in
freqs
:
wavelen
=
2
*
math
.
pi
/
freq
if
wavelen
<
high_freq_wavelen
:
new_freqs
.
append
(
freq
)
elif
wavelen
>
low_freq_wavelen
:
new_freqs
.
append
(
freq
/
scale_factor
)
else
:
else
:
assert
low_freq_wavelen
!=
high_freq_wavelen
smooth
=
0
smooth
=
(
old_context_len
/
wavelen
-
low_freq_factor
)
/
(
new_freqs
=
torch
.
where
(
high_freq_factor
-
low_freq_factor
)
wave_len
<
high_freq_wavelen
,
new_freqs
.
append
((
1
-
smooth
)
*
freq
/
scale_factor
+
inv_freqs
,
smooth
*
freq
)
torch
.
where
(
return
torch
.
tensor
(
new_freqs
,
dtype
=
freqs
.
dtype
,
device
=
freqs
.
device
)
wave_len
>
low_freq_wavelen
,
inv_freqs
/
self
.
scaling_factor
,
(
1
-
smooth
)
*
inv_freqs
/
self
.
scaling_factor
+
smooth
*
inv_freqs
,
),
)
return
new_freqs
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
...
@@ -794,6 +810,7 @@ def get_rope(
...
@@ -794,6 +810,7 @@ def get_rope(
rope_scaling_args
,
dtype
)
rope_scaling_args
,
dtype
)
if
key
in
_ROPE_DICT
:
if
key
in
_ROPE_DICT
:
return
_ROPE_DICT
[
key
]
return
_ROPE_DICT
[
key
]
if
rope_scaling
is
None
:
if
rope_scaling
is
None
:
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
)
is_neox_style
,
dtype
)
...
@@ -802,12 +819,19 @@ def get_rope(
...
@@ -802,12 +819,19 @@ def get_rope(
"type"
]
if
"type"
in
rope_scaling
else
rope_scaling
[
"rope_type"
]
"type"
]
if
"type"
in
rope_scaling
else
rope_scaling
[
"rope_type"
]
# The correct one should be "longrope" but keep "su" here
# The correct one should be "longrope" but keep "su" here
# for backward compatible
# for backward compatible
if
scaling_type
not
in
{
"su"
,
"longrope"
,
"llama3"
}:
if
scaling_type
not
in
{
"su"
,
"longrope"
}:
scaling_factor
=
rope_scaling
[
"factor"
]
scaling_factor
=
rope_scaling
[
"factor"
]
if
scaling_type
==
"llama3"
:
if
scaling_type
==
"llama3"
:
rotary_emb
=
ExtendedRotaryEmbedding
(
head_size
,
rotary_dim
,
low_freq_factor
=
rope_scaling
[
"low_freq_factor"
]
high_freq_factor
=
rope_scaling
[
"high_freq_factor"
]
original_max_position
=
rope_scaling
[
"original_max_position_embeddings"
]
rotary_emb
=
Llama3RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
max_position
,
base
,
is_neox_style
,
dtype
)
is_neox_style
,
dtype
,
scaling_factor
,
low_freq_factor
,
high_freq_factor
,
original_max_position
)
elif
scaling_type
==
"linear"
:
elif
scaling_type
==
"linear"
:
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
max_position
,
base
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment