Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
59a1eb59
Unverified
Commit
59a1eb59
authored
Jun 18, 2024
by
Shukant Pal
Committed by
GitHub
Jun 19, 2024
Browse files
[Bugfix] Fix Phi-3 Long RoPE scaling implementation (#5628)
parent
6820724e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
4 deletions
+14
-4
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+14
-4
No files found.
vllm/model_executor/layers/rotary_embedding.py
View file @
59a1eb59
...
@@ -507,8 +507,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
...
@@ -507,8 +507,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
short_factor
:
List
[
float
],
short_factor
:
List
[
float
],
long_factor
:
List
[
float
],
long_factor
:
List
[
float
],
short_mscale
:
float
=
1.
1
,
short_mscale
:
float
=
1.
0
,
long_mscale
:
float
=
1.
225
,
long_mscale
:
float
=
1.
0
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -530,6 +530,16 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
...
@@ -530,6 +530,16 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
self
.
short_mscale
=
short_mscale
self
.
short_mscale
=
short_mscale
self
.
long_mscale
=
long_mscale
self
.
long_mscale
=
long_mscale
scale
=
(
self
.
max_position_embeddings
/
self
.
original_max_position_embeddings
)
if
scale
<=
1.0
:
self
.
scaling_factor
=
1.0
else
:
self
.
scaling_factor
=
math
.
sqrt
(
1
+
math
.
log
(
scale
)
/
math
.
log
(
self
.
original_max_position_embeddings
))
short_cache
=
self
.
_compute_cos_sin_cache
(
short_cache
=
self
.
_compute_cos_sin_cache
(
original_max_position_embeddings
,
short_factor
,
short_mscale
)
original_max_position_embeddings
,
short_factor
,
short_mscale
)
short_cache
=
short_cache
.
to
(
dtype
)
short_cache
=
short_cache
.
to
(
dtype
)
...
@@ -565,8 +575,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
...
@@ -565,8 +575,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
inv_freq
=
self
.
_compute_inv_freq
(
rescale_factors
)
inv_freq
=
self
.
_compute_inv_freq
(
rescale_factors
)
t
=
torch
.
arange
(
max_position_embeddings
,
dtype
=
torch
.
float
)
t
=
torch
.
arange
(
max_position_embeddings
,
dtype
=
torch
.
float
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
freqs
.
cos
()
*
mscale
cos
=
freqs
.
cos
()
*
mscale
*
self
.
scaling_factor
sin
=
freqs
.
sin
()
*
mscale
sin
=
freqs
.
sin
()
*
mscale
*
self
.
scaling_factor
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
return
cache
return
cache
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment