Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
0fbfc4b8
Unverified
Commit
0fbfc4b8
authored
Dec 15, 2023
by
CHU Tianxiang
Committed by
GitHub
Dec 15, 2023
Browse files
Add GPTQ support (#916)
parent
c06170cc
Changes
35
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
98 additions
and
30 deletions
+98
-30
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+3
-0
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+22
-17
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+0
-1
vllm/model_executor/models/gpt_j.py
vllm/model_executor/models/gpt_j.py
+8
-1
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+0
-1
vllm/model_executor/models/internlm.py
vllm/model_executor/models/internlm.py
+8
-1
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+8
-1
vllm/model_executor/models/mistral.py
vllm/model_executor/models/mistral.py
+8
-1
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+9
-2
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+3
-0
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+8
-1
vllm/model_executor/models/phi_1_5.py
vllm/model_executor/models/phi_1_5.py
+3
-0
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+8
-2
vllm/model_executor/models/yi.py
vllm/model_executor/models/yi.py
+8
-1
vllm/model_executor/weight_utils.py
vllm/model_executor/weight_utils.py
+2
-1
No files found.
vllm/model_executor/models/chatglm.py
View file @
0fbfc4b8
...
@@ -377,6 +377,9 @@ class ChatGLMForCausalLM(nn.Module):
...
@@ -377,6 +377,9 @@ class ChatGLMForCausalLM(nn.Module):
continue
continue
if
"word_embeddings"
in
name
:
if
"word_embeddings"
in
name
:
name
=
name
.
replace
(
".word_embeddings"
,
""
)
name
=
name
.
replace
(
".word_embeddings"
,
""
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/falcon.py
View file @
0fbfc4b8
...
@@ -425,16 +425,21 @@ class FalconForCausalLM(nn.Module):
...
@@ -425,16 +425,21 @@ class FalconForCausalLM(nn.Module):
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
if
"query_key_value"
in
name
:
if
"query_key_value"
in
name
:
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
loaded_weight_shape
=
loaded_weight
.
shape
loaded_weight_shape
=
loaded_weight
.
shape
if
output_dim
is
not
None
:
loaded_weight
=
loaded_weight
.
view
(
loaded_weight
=
loaded_weight
.
view
(
loaded_weight_shape
[:
output_dim
]
+
loaded_weight_shape
[:
output_dim
]
+
(
total_num_kv_heads
,
num_query_heads_per_kv_head
+
2
,
-
1
)
+
(
total_num_kv_heads
,
num_query_heads_per_kv_head
+
2
,
loaded_weight_shape
[
output_dim
+
1
:])
-
1
)
+
loaded_weight_shape
[
output_dim
+
1
:])
wq
=
loaded_weight
.
narrow
(
wq
=
loaded_weight
.
narrow
(
output_dim
+
1
,
0
,
num_query_heads_per_kv_head
).
reshape
(
output_dim
+
1
,
0
,
num_query_heads_per_kv_head
).
reshape
(
*
loaded_weight_shape
[:
output_dim
],
-
1
,
*
loaded_weight_shape
[:
output_dim
],
-
1
,
*
loaded_weight_shape
[
output_dim
+
1
:])
*
loaded_weight_shape
[
output_dim
+
1
:])
wk
=
loaded_weight
.
narrow
(
wk
=
loaded_weight
.
narrow
(
...
...
vllm/model_executor/models/gpt2.py
View file @
0fbfc4b8
...
@@ -275,7 +275,6 @@ class GPT2LMHeadModel(nn.Module):
...
@@ -275,7 +275,6 @@ class GPT2LMHeadModel(nn.Module):
if
not
name
.
endswith
(
".weight"
):
if
not
name
.
endswith
(
".weight"
):
continue
continue
loaded_weight
=
loaded_weight
.
t
()
loaded_weight
=
loaded_weight
.
t
()
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/gpt_j.py
View file @
0fbfc4b8
...
@@ -274,11 +274,18 @@ class GPTJForCausalLM(nn.Module):
...
@@ -274,11 +274,18 @@ class GPTJForCausalLM(nn.Module):
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/gpt_neox.py
View file @
0fbfc4b8
...
@@ -72,7 +72,6 @@ class GPTNeoXAttention(nn.Module):
...
@@ -72,7 +72,6 @@ class GPTNeoXAttention(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
linear_method
=
linear_method
,
linear_method
=
linear_method
,
)
)
scaling
=
self
.
head_size
**-
0.5
scaling
=
self
.
head_size
**-
0.5
rotary_dim
=
int
(
self
.
head_size
*
config
.
rotary_pct
)
rotary_dim
=
int
(
self
.
head_size
*
config
.
rotary_pct
)
assert
rotary_dim
%
2
==
0
assert
rotary_dim
%
2
==
0
...
...
vllm/model_executor/models/internlm.py
View file @
0fbfc4b8
...
@@ -289,11 +289,18 @@ class InternLMForCausalLM(nn.Module):
...
@@ -289,11 +289,18 @@ class InternLMForCausalLM(nn.Module):
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/llama.py
View file @
0fbfc4b8
...
@@ -330,11 +330,18 @@ class LlamaForCausalLM(nn.Module):
...
@@ -330,11 +330,18 @@ class LlamaForCausalLM(nn.Module):
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/mistral.py
View file @
0fbfc4b8
...
@@ -321,11 +321,18 @@ class MistralForCausalLM(nn.Module):
...
@@ -321,11 +321,18 @@ class MistralForCausalLM(nn.Module):
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/mixtral.py
View file @
0fbfc4b8
...
@@ -153,7 +153,7 @@ class MixtralMoE(nn.Module):
...
@@ -153,7 +153,7 @@ class MixtralMoE(nn.Module):
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
num_total_experts
,
self
.
num_total_experts
,
bias
=
False
,
bias
=
False
,
linear_method
=
linear_method
)
linear_method
=
None
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
batch_size
,
sequence_length
,
hidden_dim
=
hidden_states
.
shape
batch_size
,
sequence_length
,
hidden_dim
=
hidden_states
.
shape
...
@@ -418,11 +418,18 @@ class MixtralForCausalLM(nn.Module):
...
@@ -418,11 +418,18 @@ class MixtralForCausalLM(nn.Module):
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/mpt.py
View file @
0fbfc4b8
...
@@ -297,6 +297,9 @@ class MPTForCausalLM(nn.Module):
...
@@ -297,6 +297,9 @@ class MPTForCausalLM(nn.Module):
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/opt.py
View file @
0fbfc4b8
...
@@ -345,11 +345,18 @@ class OPTForCausalLM(nn.Module):
...
@@ -345,11 +345,18 @@ class OPTForCausalLM(nn.Module):
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/phi_1_5.py
View file @
0fbfc4b8
...
@@ -305,6 +305,9 @@ class PhiForCausalLM(nn.Module):
...
@@ -305,6 +305,9 @@ class PhiForCausalLM(nn.Module):
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# pylint: disable=E1136
# pylint: disable=E1136
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
...
...
vllm/model_executor/models/qwen.py
View file @
0fbfc4b8
...
@@ -82,7 +82,6 @@ class QWenAttention(nn.Module):
...
@@ -82,7 +82,6 @@ class QWenAttention(nn.Module):
self
.
num_heads
=
(
self
.
total_num_heads
//
self
.
num_heads
=
(
self
.
total_num_heads
//
tensor_model_parallel_world_size
)
tensor_model_parallel_world_size
)
self
.
head_dim
=
hidden_size
//
self
.
total_num_heads
self
.
head_dim
=
hidden_size
//
self
.
total_num_heads
self
.
c_attn
=
QKVParallelLinear
(
self
.
c_attn
=
QKVParallelLinear
(
hidden_size
,
hidden_size
,
self
.
head_dim
,
self
.
head_dim
,
...
@@ -279,11 +278,18 @@ class QWenLMHeadModel(nn.Module):
...
@@ -279,11 +278,18 @@ class QWenLMHeadModel(nn.Module):
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/yi.py
View file @
0fbfc4b8
...
@@ -320,11 +320,18 @@ class YiForCausalLM(nn.Module):
...
@@ -320,11 +320,18 @@ class YiForCausalLM(nn.Module):
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/weight_utils.py
View file @
0fbfc4b8
...
@@ -287,4 +287,5 @@ def initialize_dummy_weights(
...
@@ -287,4 +287,5 @@ def initialize_dummy_weights(
values between -1e-3 and 1e-3 works well for most models.
values between -1e-3 and 1e-3 works well for most models.
"""
"""
for
param
in
model
.
state_dict
().
values
():
for
param
in
model
.
state_dict
().
values
():
if
torch
.
is_floating_point
(
param
):
param
.
data
.
uniform_
(
low
,
high
)
param
.
data
.
uniform_
(
low
,
high
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment