norm / vllm · Commits · 7878958c

Address Phi modeling update 2 (#2428)

Unverified commit 7878958c, authored Jan 13, 2024 by Gary Hui, committed by GitHub Jan 12, 2024
Parent: ce036244
Showing 2 changed files with 60 additions and 62 deletions.
vllm/model_executor/models/__init__.py  +1 -1
vllm/model_executor/models/phi.py       +59 -61
vllm/model_executor/models/__init__.py

@@ -33,7 +33,7 @@ _MODELS = {
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
-    "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
+    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "YiForCausalLM": ("yi", "YiForCausalLM"),
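Note: each entry in _MODELS maps a Hugging Face architecture string to a (module, class) pair under vllm/model_executor/models/, so this one-line change redirects "PhiForCausalLM" from the old phi_1_5 module to phi.py. A minimal sketch of how such a registry is typically resolved follows; resolve_model_class is illustrative here and is not vLLM's actual loader API.

import importlib

_MODELS = {
    # architecture name -> (module under vllm.model_executor.models, class name)
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
}

def resolve_model_class(architecture: str):
    module_name, class_name = _MODELS[architecture]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)

# With the updated entry, resolving "PhiForCausalLM" imports
# vllm/model_executor/models/phi.py instead of the removed phi_1_5 module.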
vllm/model_executor/models/phi.py
@@ -62,20 +62,6 @@ from vllm.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
-
-class PhiEmbedding(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-
-        self.wte = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
-    def forward(self, input_ids: torch.LongTensor):
-        return self.wte(input_ids)
-
 
 class PhiAttention(nn.Module):
 
     def __init__(self,
@@ -93,27 +79,22 @@ class PhiAttention(nn.Module):
                           tensor_model_parallel_world_size)
 
         # pylint: disable=C0103
-        self.Wqkv = QKVParallelLinear(
-            self.hidden_size,
-            self.head_size,
-            self.total_num_heads,
-            linear_method=linear_method,
-        )
         self.qkv_proj = QKVParallelLinear(
-            config.hidden_size,
+            self.hidden_size,
             self.head_size,
             self.total_num_heads,
-            bias=False,
+            bias=True,
             linear_method=linear_method,
         )
-        self.out_proj = RowParallelLinear(
+        self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             linear_method=linear_method,
         )
 
         scaling = self.head_size**-0.5
-        rotary_dim = config.rotary_dim
+        rotary_dim = int(config.partial_rotary_factor *
+                         (config.hidden_size // config.num_attention_heads))
         assert rotary_dim % 2 == 0
 
         # pylint: disable=C0301
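Note: the rewritten rotary_dim line follows the refactored Hugging Face Phi config, which exposes partial_rotary_factor rather than a rotary_dim field. A small sketch of the arithmetic, using illustrative Phi-2-style values that are assumed for this example and not taken from the commit:

# Illustrative config values, assumed for the example only.
hidden_size = 2560
num_attention_heads = 32
partial_rotary_factor = 0.4

head_size = hidden_size // num_attention_heads        # 80
rotary_dim = int(partial_rotary_factor * head_size)   # 32: only part of each
assert rotary_dim % 2 == 0                            # head gets rotated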
@@ -136,12 +117,12 @@ class PhiAttention(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        qkv, _ = self.Wqkv(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(q, k, v, k_cache, v_cache,
                                 input_metadata)
-        output, _ = self.out_proj(attn_output)
+        output, _ = self.dense(attn_output)
         return output
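Note: qkv_proj fuses the query, key, and value projections into one matmul, and the forward pass recovers the three tensors with a single chunk along the last dimension. A toy, framework-only sketch with made-up shapes (nn.Linear stands in for QKVParallelLinear):

import torch
import torch.nn as nn

hidden = torch.randn(4, 8)                   # 4 tokens, hidden size 8
qkv_proj = nn.Linear(8, 3 * 8, bias=True)    # stand-in for QKVParallelLinear
q, k, v = qkv_proj(hidden).chunk(chunks=3, dim=-1)
print(q.shape, k.shape, v.shape)             # each is torch.Size([4, 8])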
@@ -166,8 +147,7 @@ class PhiMLP(nn.Module):
             linear_method=linear_method,
         )
         quant_config = getattr(linear_method, "quant_config", None)
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              n_inner)
+        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)
@@ -182,9 +162,9 @@ class PhiLayer(nn.Module):
                  config: PretrainedConfig,
                  linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
-        self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
-        self.mixer = PhiAttention(config, linear_method)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
+        self.self_attn = PhiAttention(config, linear_method)
         self.mlp = PhiMLP(config, linear_method)
 
     def forward(
@@ -195,8 +175,8 @@ class PhiLayer(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         residual = hidden_states
-        hidden_states = self.ln(hidden_states)
-        attn_outputs = self.mixer(
+        hidden_states = self.input_layernorm(hidden_states)
+        attn_outputs = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
@@ -215,11 +195,14 @@ class PhiModel(nn.Module):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.embd = PhiEmbedding(config)
-        self.h = nn.ModuleList([
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
+        self.layers = nn.ModuleList([
             PhiLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
+        self.final_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
 
     def forward(
         self,
@@ -228,27 +211,19 @@ class PhiModel(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.embd(input_ids)
+        hidden_states = self.embed_tokens(input_ids)
         for i in range(self.config.num_hidden_layers):
-            layer = self.h[i]
+            layer = self.layers[i]
             hidden_states = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
             )
-        return hidden_states
-
-
-class PhiCausalLMHead(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-        self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
-        self.linear = ParallelLMHead(config.vocab_size,
-                                     config.hidden_size,
-                                     bias=True)
+        hidden_states = self.final_layernorm(hidden_states)
+        return hidden_states
 
 
 class PhiForCausalLM(nn.Module):
@@ -260,8 +235,11 @@ class PhiForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
-        self.transformer = PhiModel(config, linear_method)
-        self.lm_head = PhiCausalLMHead(config)
+        self.model = PhiModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      bias=True)
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
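Note: the attribute renames in this hunk (transformer becomes model, the PhiCausalLMHead wrapper is replaced by a plain ParallelLMHead) matter because load_weights looks parameters up by the names that named_parameters() derives from these attributes. A minimal, hypothetical sketch of that naming effect, with nn.Linear standing in for the real submodules:

import torch.nn as nn

class TinyCausalLM(nn.Module):
    # Stand-ins only: nn.Linear here is not PhiModel or ParallelLMHead.
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 4)
        self.lm_head = nn.Linear(4, 8)

print(sorted(dict(TinyCausalLM().named_parameters())))
# ['lm_head.bias', 'lm_head.weight', 'model.bias', 'model.weight']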
@@ -271,9 +249,9 @@ class PhiForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
+        hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata)
-        hidden_states = self.lm_head.ln(hidden_states)
         return hidden_states
 
     def sample(
@@ -281,7 +259,7 @@ class PhiForCausalLM(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        head = self.lm_head.linear
+        head = self.lm_head
         next_tokens = self.sampler(head.weight, hidden_states,
                                    sampling_metadata, head.bias)
         return next_tokens
@@ -291,17 +269,37 @@ class PhiForCausalLM(nn.Module):
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v")
+        ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            # pylint: disable=E1136
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # pylint: disable=E1136
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
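Note: the new stacked_params_mapping loop follows the pattern vLLM uses elsewhere for fused projections: separate q_proj/k_proj/v_proj tensors from the checkpoint are renamed onto the single qkv_proj parameter and handed to its shard-aware weight_loader. A toy rename-only sketch, where the checkpoint key is illustrative and not taken from this commit:

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

checkpoint_key = "model.layers.0.self_attn.q_proj.weight"  # illustrative
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in checkpoint_key:
        target = checkpoint_key.replace(weight_name, param_name)
        print(target, shard_id)
        # -> model.layers.0.self_attn.qkv_proj.weight q
        break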