Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96e90fde
Unverified
Commit
96e90fde
authored
Apr 25, 2024
by
Caio Mendes
Committed by
GitHub
Apr 25, 2024
Browse files
[Model] Adds Phi-3 support (#4298)
parent
a395a638
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
149 additions
and
9 deletions
+149
-9
README.md
README.md
+1
-0
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+4
-0
vllm/config.py
vllm/config.py
+1
-1
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+133
-3
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+1
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+9
-5
No files found.
README.md
View file @
96e90fde
...
@@ -78,6 +78,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
...
@@ -78,6 +78,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
-
OPT (
`facebook/opt-66b`
,
`facebook/opt-iml-max-30b`
, etc.)
-
OPT (
`facebook/opt-66b`
,
`facebook/opt-iml-max-30b`
, etc.)
-
Orion (
`OrionStarAI/Orion-14B-Base`
,
`OrionStarAI/Orion-14B-Chat`
, etc.)
-
Orion (
`OrionStarAI/Orion-14B-Base`
,
`OrionStarAI/Orion-14B-Chat`
, etc.)
-
Phi (
`microsoft/phi-1_5`
,
`microsoft/phi-2`
, etc.)
-
Phi (
`microsoft/phi-1_5`
,
`microsoft/phi-2`
, etc.)
-
Phi3 (
`microsoft/Phi-3-mini-4k-instruct`
,
`microsoft/Phi-3-mini-128k-instruct`
, etc.)
-
Qwen (
`Qwen/Qwen-7B`
,
`Qwen/Qwen-7B-Chat`
, etc.)
-
Qwen (
`Qwen/Qwen-7B`
,
`Qwen/Qwen-7B-Chat`
, etc.)
-
Qwen2 (
`Qwen/Qwen1.5-7B`
,
`Qwen/Qwen1.5-7B-Chat`
, etc.)
-
Qwen2 (
`Qwen/Qwen1.5-7B`
,
`Qwen/Qwen1.5-7B-Chat`
, etc.)
-
Qwen2MoE (
`Qwen/Qwen1.5-MoE-A2.7B`
,
`Qwen/Qwen1.5-MoE-A2.7B-Chat`
, etc.)
-
Qwen2MoE (
`Qwen/Qwen1.5-MoE-A2.7B`
,
`Qwen/Qwen1.5-MoE-A2.7B-Chat`
, etc.)
...
...
docs/source/models/supported_models.rst
View file @
96e90fde
...
@@ -115,6 +115,10 @@ Alongside each architecture, we include some popular models that use it.
...
@@ -115,6 +115,10 @@ Alongside each architecture, we include some popular models that use it.
- Phi
- Phi
- :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
- :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
-
-
* - :code:`Phi3ForCausalLM`
- Phi-3
- :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc.
-
* - :code:`QWenLMHeadModel`
* - :code:`QWenLMHeadModel`
- Qwen
- Qwen
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
...
...
vllm/config.py
View file @
96e90fde
...
@@ -1056,7 +1056,7 @@ def _get_and_verify_max_len(
...
@@ -1056,7 +1056,7 @@ def _get_and_verify_max_len(
derived_max_model_len
=
default_max_len
derived_max_model_len
=
default_max_len
rope_scaling
=
getattr
(
hf_config
,
"rope_scaling"
,
None
)
rope_scaling
=
getattr
(
hf_config
,
"rope_scaling"
,
None
)
if
rope_scaling
is
not
None
:
if
rope_scaling
is
not
None
and
rope_scaling
[
"type"
]
!=
"su"
:
assert
"factor"
in
rope_scaling
assert
"factor"
in
rope_scaling
scaling_factor
=
rope_scaling
[
"factor"
]
scaling_factor
=
rope_scaling
[
"factor"
]
if
rope_scaling
[
"type"
]
==
"yarn"
:
if
rope_scaling
[
"type"
]
==
"yarn"
:
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
96e90fde
...
@@ -338,6 +338,114 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -338,6 +338,114 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
return
cache
return
cache
class Phi3SuScaledRotaryEmbedding(nn.Module):
    """Rotary embedding with "su" scaling, as used by the Phi-3 models.

    Two cos/sin tables are precomputed at construction time:

    * a *short* table with ``original_max_position_embeddings`` rows, built
      from ``short_factor`` and scaled by ``short_mscale``;
    * a *long* table with ``max_position_embeddings`` rows, built from
      ``long_factor`` and scaled by ``long_mscale``.

    Both tables are concatenated (short first) into
    ``long_short_cos_sin_cache``. ``forward`` reads from the short part
    unless at least one position in the call falls outside the short table,
    in which case every position of that call reads from the long part.

    Based on the original RotaryEmbedding implementation.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        short_factor: List[float],
        long_factor: List[float],
        short_mscale: float = 1.1,
        long_mscale: float = 1.225,
    ):
        """Validate the configuration and precompute the cos/sin tables.

        Args:
            head_size: Size of one attention head.
            rotary_dim: Number of rotated dimensions; must equal
                ``head_size``.
            max_position_embeddings: Number of rows of the long table.
            original_max_position_embeddings: Number of rows of the short
                table.
            base: Base of the rotary frequencies.
            is_neox_style: Must be truthy; only the neox layout is supported.
            short_factor: Per-frequency rescale factors of the short table
                (one value per pair of head dimensions).
            long_factor: Per-frequency rescale factors of the long table.
            short_mscale: Multiplier applied to the short cos/sin values.
            long_mscale: Multiplier applied to the long cos/sin values.

        Raises:
            ValueError: If ``rotary_dim != head_size`` or ``is_neox_style``
                is falsy.
        """
        super().__init__()

        if rotary_dim != head_size:
            # Implicit string concatenation instead of a backslash
            # continuation, so no source indentation leaks into the message.
            raise ValueError(
                "`Phi3SuScaledRotaryEmbedding` does not support "
                f"rotary_dim != head_size ({rotary_dim}!={head_size}).")
        if not is_neox_style:
            raise ValueError(
                "`Phi3SuScaledRotaryEmbedding` only supports neox_style.")

        self.head_size = head_size
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = (
            original_max_position_embeddings)
        self.base = base
        self.short_factor = short_factor
        self.long_factor = long_factor
        self.short_mscale = short_mscale
        self.long_mscale = long_mscale

        short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale)
        short_cache = short_cache.to(torch.get_default_dtype())
        self.register_buffer("short_cos_sin_cache",
                             short_cache,
                             persistent=False)

        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                 long_factor, long_mscale)
        long_cache = long_cache.to(torch.get_default_dtype())
        self.register_buffer("long_cos_sin_cache",
                             long_cache,
                             persistent=False)

        # Short rows come first: row ``p`` is the short entry for position
        # ``p`` and row ``original_max_position_embeddings + p`` is the long
        # entry for the same position.
        long_short_cache = torch.cat(
            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
        self.register_buffer("long_short_cos_sin_cache",
                             long_short_cache,
                             persistent=False)

    def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
        """Return the rescaled inverse frequencies.

        Computes ``1 / (rescale_factors * base ** (2i / head_size))`` for
        ``i`` in ``range(head_size // 2)``.
        """
        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
        exponents = torch.arange(0, self.head_size, 2,
                                 dtype=torch.float) / self.head_size
        inv_freq = 1.0 / (rescale_factors * (self.base**exponents))
        return inv_freq

    def _compute_cos_sin_cache(
        self,
        max_position_embeddings: int,
        rescale_factors: List[float],
        mscale: float,
    ) -> torch.Tensor:
        """Build one cos/sin table.

        Returns a ``(max_position_embeddings, head_size)`` tensor whose first
        half of the last dimension holds ``cos * mscale`` and whose second
        half holds ``sin * mscale``.
        """
        inv_freq = self._compute_inv_freq(rescale_factors)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        return torch.cat((cos, sin), dim=-1)

    def _cache_indices(
        self,
        positions: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Map token positions to rows of ``long_short_cos_sin_cache``.

        The short table only has rows for positions
        ``0 .. original_max_position_embeddings - 1``. If any position in the
        call is at or beyond ``original_max_position_embeddings``, all
        positions are shifted by that amount so they address the long table.
        ``offsets``, when given, is added on top.
        """
        k = self.original_max_position_embeddings
        # ``>=`` rather than ``>``: positions are 0-based, so position ``k``
        # is already outside the short table. With ``>`` a call whose largest
        # position is exactly ``k`` would read row ``k`` of the concatenated
        # table, i.e. the long entry for position 0.
        # The shift is computed with tensor arithmetic so no Python-level
        # branch on a tensor value is needed.
        use_long = torch.any(positions >= k).to(positions.dtype)
        idx = positions + use_long * k
        if offsets is not None:
            idx = idx + offsets
        return idx

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding to ``query`` and ``key``.

        Args:
            positions: 1-D integer tensor of token positions (it is used as
                the index of ``torch.index_select``).
            query: Tensor whose last dimension is a multiple of
                ``head_size``; it is viewed as ``(..., heads, head_size)``.
            key: Same layout as ``query``.
            offsets: Optional tensor added to the computed cache indices.

        Returns:
            The rotated ``(query, key)``, with the head dimensions flattened
            back into the last dimension.
        """
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        idx = self._cache_indices(positions, offsets)

        # The buffer may live on a different device than the inputs; move it
        # (and keep the moved copy) before indexing.
        self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
            idx.device)
        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

        cos, sin = cos_sin.chunk(2, dim=-1)
        # Duplicate the half-width cos/sin to full head width and add a
        # broadcast dimension for the heads.
        cos = cos.repeat(1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 2).unsqueeze(-2)

        query = query * cos + _rotate_neox(query) * sin
        key = key * cos + _rotate_neox(key) * sin

        return query.flatten(-2), key.flatten(-2)
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
...
@@ -349,17 +457,26 @@ def get_rope(
...
@@ -349,17 +457,26 @@ def get_rope(
is_neox_style
:
bool
=
True
,
is_neox_style
:
bool
=
True
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
RotaryEmbedding
:
)
->
RotaryEmbedding
:
if
rope_scaling
is
not
None
:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple
=
{
k
:
tuple
(
v
)
if
isinstance
(
v
,
list
)
else
v
for
k
,
v
in
rope_scaling
.
items
()
}
rope_scaling_args
=
tuple
(
rope_scaling_tuple
.
items
())
else
:
rope_scaling_args
=
None
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
tuple
(
rope_scaling
.
items
())
if
rope_scaling
is
not
None
else
None
)
rope_scaling
_args
)
if
key
in
_ROPE_DICT
:
if
key
in
_ROPE_DICT
:
return
_ROPE_DICT
[
key
]
return
_ROPE_DICT
[
key
]
if
rope_scaling
is
None
:
if
rope_scaling
is
None
:
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
)
is_neox_style
)
else
:
else
:
scaling_type
=
rope_scaling
[
"type"
]
scaling_type
=
rope_scaling
[
"type"
]
scaling_factor
=
rope_scaling
[
"factor"
]
if
scaling_type
!=
"su"
:
scaling_factor
=
rope_scaling
[
"factor"
]
if
scaling_type
==
"linear"
:
if
scaling_type
==
"linear"
:
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
max_position
,
base
,
...
@@ -383,6 +500,19 @@ def get_rope(
...
@@ -383,6 +500,19 @@ def get_rope(
base
,
is_neox_style
,
base
,
is_neox_style
,
scaling_factor
,
scaling_factor
,
**
extra_kwargs
)
**
extra_kwargs
)
elif
scaling_type
==
"su"
:
short_factor
=
rope_scaling
[
"short_factor"
]
long_factor
=
rope_scaling
[
"long_factor"
]
original_max_position
=
rope_scaling
[
"original_max_position_embeddings"
]
extra_kwargs
=
{
k
:
v
for
k
,
v
in
rope_scaling
.
items
()
if
k
in
(
"short_mscale"
,
"long_mscale"
)
}
rotary_emb
=
Phi3SuScaledRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
base
,
is_neox_style
,
short_factor
,
long_factor
,
**
extra_kwargs
)
else
:
else
:
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
_ROPE_DICT
[
key
]
=
rotary_emb
_ROPE_DICT
[
key
]
=
rotary_emb
...
...
vllm/model_executor/models/__init__.py
View file @
96e90fde
...
@@ -46,6 +46,7 @@ _MODELS = {
...
@@ -46,6 +46,7 @@ _MODELS = {
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"Phi3ForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2MoeForCausalLM"
:
(
"qwen2_moe"
,
"Qwen2MoeForCausalLM"
),
"Qwen2MoeForCausalLM"
:
(
"qwen2_moe"
,
"Qwen2MoeForCausalLM"
),
...
...
vllm/model_executor/models/llama.py
View file @
96e90fde
...
@@ -180,6 +180,10 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -180,6 +180,10 @@ class LlamaDecoderLayer(nn.Module):
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
if
rope_scaling
is
not
None
and
getattr
(
config
,
"original_max_position_embeddings"
,
None
):
rope_scaling
[
"original_max_position_embeddings"
]
=
(
config
.
original_max_position_embeddings
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
8192
)
sliding_window
=
getattr
(
config
,
"sliding_window"
,
None
)
sliding_window
=
getattr
(
config
,
"sliding_window"
,
None
)
...
@@ -378,11 +382,11 @@ class LlamaForCausalLM(nn.Module):
...
@@ -378,11 +382,11 @@ class LlamaForCausalLM(nn.Module):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"
.
qkv_proj"
,
"
.
q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"
.
qkv_proj"
,
"
.
k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"
.
qkv_proj"
,
"
.
v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"
.
gate_up_proj"
,
"
.
gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"
.
gate_up_proj"
,
"
.
up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment