norm / vllm · Commits · 7878958c

Unverified commit 7878958c
Authored Jan 13, 2024 by Gary Hui; committed by GitHub, Jan 12, 2024

Address Phi modeling update 2 (#2428)

Parent: ce036244
Showing 2 changed files with 60 additions and 62 deletions (+60 −62)

vllm/model_executor/models/__init__.py  +1 −1
vllm/model_executor/models/phi.py       +59 −61
vllm/model_executor/models/__init__.py

@@ -33,7 +33,7 @@ _MODELS = {
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
-    "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
+    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "YiForCausalLM": ("yi", "YiForCausalLM"),
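The registry edit above only repoints the "PhiForCausalLM" architecture at the renamed phi module. As a rough illustration of how such a (module name, class name) registry is typically resolved, here is a minimal sketch — the resolve_model_class helper below is hypothetical and not vLLM's actual loader API:

import importlib
from typing import Type

# Architecture string (from the HF config's `architectures` field)
# -> (module under vllm.model_executor.models, class name).
_MODELS = {
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),  # was ("phi_1_5", ...) before this commit
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
}


def resolve_model_class(architecture: str) -> Type:
    """Hypothetical helper: lazily import the module and fetch the class."""
    module_name, class_name = _MODELS[architecture]
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)

This is why the file rename has to be mirrored here: the class name stays the same, but the import path changes.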
vllm/model_executor/models/phi.py

@@ -62,20 +62,6 @@ from vllm.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class PhiEmbedding(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-
-        self.wte = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
-    def forward(self, input_ids: torch.LongTensor):
-        return self.wte(input_ids)
-
-
 class PhiAttention(nn.Module):
 
     def __init__(self,
@@ -93,27 +79,22 @@ class PhiAttention(nn.Module):
             tensor_model_parallel_world_size)
 
-        # pylint: disable=C0103
-        self.Wqkv = QKVParallelLinear(
-            self.hidden_size,
-            self.head_size,
-            self.total_num_heads,
-            linear_method=linear_method,
-        )
-        self.out_proj = RowParallelLinear(
+        self.qkv_proj = QKVParallelLinear(
+            config.hidden_size,
+            self.head_size,
+            self.total_num_heads,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             linear_method=linear_method,
         )
 
         scaling = self.head_size**-0.5
-        rotary_dim = config.rotary_dim
+        rotary_dim = int(config.partial_rotary_factor *
+                         (config.hidden_size // config.num_attention_heads))
         assert rotary_dim % 2 == 0
         # pylint: disable=C0301
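The rotary change replaces a dedicated rotary_dim config field with a value derived from partial_rotary_factor, matching the updated Hugging Face Phi config. A quick standalone check of that arithmetic, using assumed phi-1.5-like values (2048 hidden size, 32 heads, factor 0.5 — not values taken from this diff):

# Assumed, illustrative config values for a phi-1.5-sized model.
hidden_size = 2048
num_attention_heads = 32
partial_rotary_factor = 0.5

head_size = hidden_size // num_attention_heads       # 64
rotary_dim = int(partial_rotary_factor * head_size)  # 32

assert rotary_dim % 2 == 0    # same sanity check as in the diff
print(head_size, rotary_dim)  # -> 64 32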
@@ -136,12 +117,12 @@ class PhiAttention(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        qkv, _ = self.Wqkv(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
-        output, _ = self.out_proj(attn_output)
+        output, _ = self.dense(attn_output)
         return output
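Functionally the forward pass is unchanged: the fused projection output is still split into equal query, key, and value slices before rotary embedding and attention. A minimal torch sketch of that split, with illustrative shapes only:

import torch

num_tokens, hidden_size = 4, 2048               # illustrative shapes
qkv = torch.randn(num_tokens, 3 * hidden_size)  # stand-in for the qkv_proj output

# Same split the model performs: three equal chunks along the last dimension.
q, k, v = qkv.chunk(chunks=3, dim=-1)
assert q.shape == k.shape == v.shape == (num_tokens, hidden_size)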
@@ -166,8 +147,7 @@ class PhiMLP(nn.Module):
             linear_method=linear_method,
         )
         quant_config = getattr(linear_method, "quant_config", None)
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              n_inner)
+        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)
@@ -182,9 +162,9 @@ class PhiLayer(nn.Module):
                  config: PretrainedConfig,
                  linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
-        self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
-        self.mixer = PhiAttention(config, linear_method)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
+        self.self_attn = PhiAttention(config, linear_method)
         self.mlp = PhiMLP(config, linear_method)
 
     def forward(
@@ -195,8 +175,8 @@ class PhiLayer(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         residual = hidden_states
-        hidden_states = self.ln(hidden_states)
-        attn_outputs = self.mixer(
+        hidden_states = self.input_layernorm(hidden_states)
+        attn_outputs = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
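Only the renamed layernorm and attention call are visible in this hunk, but it helps to recall that the Phi decoder layer is a parallel block: attention and MLP both consume the same normalized hidden states, and their outputs are added to one shared residual. A compact stand-in sketch of that structure (plain nn.Linear layers replace the real attention and MLP modules; this is an illustration, not the vLLM module):

import torch
from torch import nn


class ToyParallelBlock(nn.Module):
    """Parallel attention/MLP residual block in the Phi style (illustrative)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.Linear(hidden_size, hidden_size)  # stand-in for attention
        self.mlp = nn.Linear(hidden_size, hidden_size)        # stand-in for the MLP

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(hidden_states)  # both branches see the same
        mlp_out = self.mlp(hidden_states)         # normalized input
        return attn_out + mlp_out + residual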
@@ -215,11 +195,14 @@ class PhiModel(nn.Module):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.embd = PhiEmbedding(config)
-        self.h = nn.ModuleList([
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
+        self.layers = nn.ModuleList([
             PhiLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
+        self.final_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
 
     def forward(
         self,
@@ -228,27 +211,19 @@ class PhiModel(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.embd(input_ids)
+        hidden_states = self.embed_tokens(input_ids)
         for i in range(self.config.num_hidden_layers):
-            layer = self.h[i]
+            layer = self.layers[i]
             hidden_states = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
             )
-        return hidden_states
-
-
-class PhiCausalLMHead(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-        self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
-        self.linear = ParallelLMHead(config.vocab_size,
-                                     config.hidden_size,
-                                     bias=True)
+        hidden_states = self.final_layernorm(hidden_states)
+        return hidden_states
 
 
 class PhiForCausalLM(nn.Module):
@@ -260,8 +235,11 @@ class PhiForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
-        self.transformer = PhiModel(config, linear_method)
-        self.lm_head = PhiCausalLMHead(config)
+        self.model = PhiModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      bias=True)
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
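Taken together, the renames in this file track the updated Hugging Face Phi checkpoint layout: the wrapper modules (PhiEmbedding, PhiCausalLMHead) disappear and every submodule is renamed to match the new parameter names. A rough old-to-new summary inferred from the diff; the exact key strings are illustrative, with {i} standing for a layer index:

# Old vLLM/phi-1_5 parameter prefix     -> new prefix after this commit (inferred)
PHI_PARAM_RENAMES = {
    "transformer.embd.wte":              "model.embed_tokens",
    "transformer.h.{i}.ln":              "model.layers.{i}.input_layernorm",
    "transformer.h.{i}.mixer.Wqkv":      "model.layers.{i}.self_attn.qkv_proj",
    "transformer.h.{i}.mixer.out_proj":  "model.layers.{i}.self_attn.dense",
    "lm_head.ln":                        "model.final_layernorm",
    "lm_head.linear":                    "lm_head",
}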
@@ -271,9 +249,9 @@ class PhiForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata)
-        hidden_states = self.lm_head.ln(hidden_states)
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata)
         return hidden_states
 
     def sample(
@@ -281,7 +259,7 @@ class PhiForCausalLM(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        head = self.lm_head.linear
+        head = self.lm_head
         next_tokens = self.sampler(head.weight, hidden_states,
                                    sampling_metadata, head.bias)
         return next_tokens
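Because the LM head is now a plain ParallelLMHead with bias=True, the sampler is handed the head's weight and bias directly instead of reaching through lm_head.linear. A minimal torch illustration of the logits computation this implies (the Sampler itself is vLLM-specific and not reproduced here):

import torch

vocab_size, hidden_size, num_tokens = 50, 16, 3   # toy sizes
weight = torch.randn(vocab_size, hidden_size)     # plays the role of head.weight
bias = torch.randn(vocab_size)                    # plays the role of head.bias
hidden_states = torch.randn(num_tokens, hidden_size)

# Project hidden states to vocabulary logits, including the bias term.
logits = hidden_states @ weight.t() + bias
next_token_ids = torch.argmax(logits, dim=-1)     # greedy pick, for illustration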
@@ -291,17 +269,37 @@ class PhiForCausalLM(nn.Module):
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v")
+        ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            # pylint: disable=E1136
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # pylint: disable=E1136
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
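The new loader has to stitch the checkpoint's separate q_proj/k_proj/v_proj tensors into the single fused qkv_proj parameter; that is what stacked_params_mapping plus the per-parameter weight_loader accomplish. A self-contained sketch of the same idea on plain tensors — toy_qkv_weight_loader below only mimics, and is not, vLLM's real shard-aware loader:

import torch

hidden_size = 8  # toy size
# Fused parameter laid out as [q; k; v] along dim 0.
qkv_weight = torch.empty(3 * hidden_size, hidden_size)

shard_offsets = {"q": 0, "k": hidden_size, "v": 2 * hidden_size}


def toy_qkv_weight_loader(param, loaded, shard_id):
    """Copy one of q/k/v into its slice of the fused parameter (illustrative)."""
    start = shard_offsets[shard_id]
    param[start:start + hidden_size].copy_(loaded)


stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

# Pretend checkpoint entries, one per separate projection (names are illustrative).
checkpoint = {
    f"model.layers.0.self_attn.{proj}.weight": torch.randn(hidden_size, hidden_size)
    for proj in ("q_proj", "k_proj", "v_proj")
}

for name, loaded_weight in checkpoint.items():
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # Same rename the diff performs before looking the parameter up.
        name = name.replace(weight_name, param_name)
        toy_qkv_weight_loader(qkv_weight, loaded_weight, shard_id)
        break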