Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2cd4d58d
Unverified
Commit
2cd4d58d
authored
May 24, 2025
by
ztang2370
Committed by
GitHub
May 24, 2025
Browse files
[Model] use AutoWeightsLoader for gpt2 (#18625)
Signed-off-by:
zt2370
<
ztang2370@gmail.com
>
parent
6d166a8d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
30 deletions
+43
-30
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+43
-30
No files found.
vllm/model_executor/models/gpt2.py
View file @
2cd4d58d
...
@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -235,6 +235,35 @@ class GPT2Model(nn.Module):
...
@@ -235,6 +235,35 @@ class GPT2Model(nn.Module):
hidden_states
=
self
.
ln_f
(
hidden_states
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
".attn.bias"
in
name
or
".attn.masked_bias"
in
name
:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for
conv1d_weight_name
in
[
"c_attn"
,
"c_proj"
,
"c_fc"
]:
if
conv1d_weight_name
not
in
name
:
continue
if
not
name
.
endswith
(
".weight"
):
continue
loaded_weight
=
loaded_weight
.
t
()
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
GPT2LMHeadModel
(
nn
.
Module
,
SupportsPP
):
class
GPT2LMHeadModel
(
nn
.
Module
,
SupportsPP
):
...
@@ -283,32 +312,16 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
...
@@ -283,32 +312,16 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loader
=
AutoWeightsLoader
(
self
)
loaded_params
:
set
[
str
]
=
set
()
weights
=
_add_transformer_prefix
(
weights
)
for
name
,
loaded_weight
in
weights
:
return
loader
.
load_weights
(
weights
)
if
".attn.bias"
in
name
or
".attn.masked_bias"
in
name
:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
def
_add_transformer_prefix
(
continue
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]
if
not
name
.
startswith
(
"transformer."
)
and
not
name
.
startswith
(
)
->
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]:
"lm_head"
):
for
name
,
tensor
in
weights
:
name
=
"transformer."
+
name
if
not
name
.
startswith
(
'transformer.'
)
and
not
name
.
startswith
(
"lm_head"
):
if
is_pp_missing_parameter
(
name
,
self
):
name
=
'transformer.'
+
name
continue
yield
name
,
tensor
param
=
params_dict
[
name
]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for
conv1d_weight_name
in
[
"c_attn"
,
"c_proj"
,
"c_fc"
]:
if
conv1d_weight_name
not
in
name
:
continue
if
not
name
.
endswith
(
".weight"
):
continue
loaded_weight
=
loaded_weight
.
t
()
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment