Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
33e0823d
Unverified
Commit
33e0823d
authored
May 17, 2024
by
Jinzhen Lin
Committed by
GitHub
May 17, 2024
Browse files
[Bugfix] fix rope error when load models with different dtypes (#4835)
parent
26148120
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
64 additions
and
13 deletions
+64
-13
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+43
-1
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+21
-12
No files found.
tests/kernels/test_pos_encoding.py
View file @
33e0823d
from
itertools
import
accumulate
from
itertools
import
accumulate
,
product
from
typing
import
List
,
Optional
import
pytest
...
...
@@ -207,3 +207,45 @@ def test_batched_rotary_embedding_multi_lora(
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
@
torch
.
inference_mode
()
def
test_rope_module_cache
():
MAX_POSITIONS
=
[
123
,
1234
]
BASES
=
[
10000
,
1000000
]
ROPE_SCALINGS
=
[
None
,
{
"type"
:
"linear"
,
"factor"
:
(
1
,
)
},
{
"type"
:
"dynamic"
,
"factor"
:
1
}
]
settings
=
[
HEAD_SIZES
,
ROTARY_DIMS
,
MAX_POSITIONS
,
BASES
,
IS_NEOX_STYLE
,
ROPE_SCALINGS
,
DTYPES
]
rope_setting_id_map
=
{}
for
setting
in
product
(
*
settings
):
head_size
,
rotary_dim
,
max_position
,
base
,
\
is_neox_stype
,
rope_scaling
,
dtype
=
setting
if
rotary_dim
is
None
:
rotary_dim
=
head_size
rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_stype
,
rope_scaling
,
dtype
)
# different settings cannot share the same rope module
assert
id
(
rope
)
not
in
rope_setting_id_map
.
values
()
assert
all
(
x
.
dtype
==
dtype
for
x
in
rope
.
buffers
())
assert
all
(
x
.
dtype
==
dtype
for
x
in
rope
.
parameters
())
rope_setting_id_map
[
str
(
setting
)]
=
id
(
rope
)
for
setting
in
product
(
*
settings
):
head_size
,
rotary_dim
,
max_position
,
base
,
\
is_neox_stype
,
rope_scaling
,
dtype
=
setting
if
rotary_dim
is
None
:
rotary_dim
=
head_size
rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_stype
,
rope_scaling
,
dtype
)
# check if cache take effect
assert
id
(
rope
)
==
rope_setting_id_map
[
str
(
setting
)]
vllm/model_executor/layers/rotary_embedding.py
View file @
33e0823d
...
...
@@ -53,6 +53,7 @@ class RotaryEmbedding(nn.Module):
max_position_embeddings
:
int
,
base
:
int
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
)
->
None
:
super
().
__init__
()
self
.
head_size
=
head_size
...
...
@@ -62,7 +63,7 @@ class RotaryEmbedding(nn.Module):
self
.
is_neox_style
=
is_neox_style
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
cache
.
to
(
torch
.
get_default_
dtype
()
)
cache
=
cache
.
to
(
dtype
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
...
...
@@ -178,12 +179,13 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
base
:
int
,
is_neox_style
:
bool
,
scaling_factors
:
Union
[
List
[
float
],
float
],
dtype
:
torch
.
dtype
,
)
->
None
:
if
isinstance
(
scaling_factors
,
float
):
scaling_factors
=
[
scaling_factors
]
self
.
scaling_factors
=
scaling_factors
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
)
is_neox_style
,
dtype
)
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
inv_freq
=
self
.
_compute_inv_freq
(
self
.
base
)
...
...
@@ -219,10 +221,11 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
base
:
int
,
is_neox_style
:
bool
,
scaling_factor
:
float
,
dtype
:
torch
.
dtype
,
)
->
None
:
self
.
scaling_factor
=
scaling_factor
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
)
is_neox_style
,
dtype
)
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
# NOTE(woosuk): self.max_position_embeddings is the original
...
...
@@ -299,6 +302,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
base
:
int
,
is_neox_style
:
bool
,
scaling_factor
:
float
,
dtype
:
torch
.
dtype
,
*
,
extrapolation_factor
:
float
=
1
,
attn_factor
:
float
=
1
,
...
...
@@ -314,7 +318,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
self
.
mscale
=
float
(
_yarn_get_mscale
(
self
.
scaling_factor
)
*
attn_factor
)
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
)
is_neox_style
,
dtype
)
def
_compute_inv_freq
(
self
,
scaling_factor
:
float
)
->
torch
.
Tensor
:
pos_freqs
=
self
.
base
**
(
...
...
@@ -359,6 +363,7 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
original_max_position_embeddings
:
int
,
base
:
int
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
short_factor
:
List
[
float
],
long_factor
:
List
[
float
],
short_mscale
:
float
=
1.1
,
...
...
@@ -385,14 +390,14 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
short_cache
=
self
.
_compute_cos_sin_cache
(
original_max_position_embeddings
,
short_factor
,
short_mscale
)
short_cache
=
short_cache
.
to
(
torch
.
get_default_
dtype
()
)
short_cache
=
short_cache
.
to
(
dtype
)
self
.
register_buffer
(
"short_cos_sin_cache"
,
short_cache
,
persistent
=
False
)
long_cache
=
self
.
_compute_cos_sin_cache
(
max_position_embeddings
,
long_factor
,
long_mscale
)
long_cache
=
long_cache
.
to
(
torch
.
get_default_
dtype
()
)
long_cache
=
long_cache
.
to
(
dtype
)
self
.
register_buffer
(
"long_cos_sin_cache"
,
long_cache
,
persistent
=
False
)
...
...
@@ -463,7 +468,10 @@ def get_rope(
base
:
int
,
is_neox_style
:
bool
=
True
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
RotaryEmbedding
:
if
dtype
is
None
:
dtype
=
torch
.
get_default_dtype
()
if
rope_scaling
is
not
None
:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple
=
{
...
...
@@ -474,12 +482,12 @@ def get_rope(
else
:
rope_scaling_args
=
None
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
rope_scaling_args
)
rope_scaling_args
,
dtype
)
if
key
in
_ROPE_DICT
:
return
_ROPE_DICT
[
key
]
if
rope_scaling
is
None
:
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
)
is_neox_style
,
dtype
)
else
:
scaling_type
=
rope_scaling
[
"type"
]
if
scaling_type
!=
"su"
:
...
...
@@ -488,11 +496,11 @@ def get_rope(
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
)
scaling_factor
,
dtype
)
elif
scaling_type
==
"dynamic"
:
rotary_emb
=
DynamicNTKScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
)
scaling_factor
,
dtype
)
elif
scaling_type
==
"yarn"
:
original_max_position
=
rope_scaling
[
"original_max_position_embeddings"
]
...
...
@@ -505,7 +513,7 @@ def get_rope(
rotary_emb
=
YaRNScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
original_max_position
,
base
,
is_neox_style
,
scaling_factor
,
scaling_factor
,
dtype
,
**
extra_kwargs
)
elif
scaling_type
==
"su"
:
short_factor
=
rope_scaling
[
"short_factor"
]
...
...
@@ -519,7 +527,8 @@ def get_rope(
}
rotary_emb
=
Phi3SuScaledRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
base
,
is_neox_style
,
short_factor
,
long_factor
,
**
extra_kwargs
)
base
,
is_neox_style
,
dtype
,
short_factor
,
long_factor
,
**
extra_kwargs
)
else
:
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
_ROPE_DICT
[
key
]
=
rotary_emb
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment