Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
162ba383
Commit
162ba383
authored
Jul 05, 2019
by
thomwolf
Browse files
fix model loading
parent
6dacc79d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
4 deletions
+28
-4
examples/run_bert_classifier.py
examples/run_bert_classifier.py
+2
-1
pytorch_transformers/modeling_utils.py
pytorch_transformers/modeling_utils.py
+21
-1
pytorch_transformers/tests/modeling_utils_test.py
pytorch_transformers/tests/modeling_utils_test.py
+5
-2
No files found.
examples/run_bert_classifier.py
View file @
162ba383
...
...
@@ -308,7 +308,8 @@ def main():
input_ids, input_mask, segment_ids, label_ids = batch

# define a new function to compute loss values for both output_modes
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
ouputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
loss =
if output_mode == "classification":
    loss_fct = CrossEntropyLoss()
...
...
pytorch_transformers/modeling_utils.py
View file @
162ba383
...
...
@@ -193,7 +193,8 @@ class PreTrainedModel(nn.Module):
"""
state_dict = kwargs.pop('state_dict', None)
cache_dir = kwargs.pop('cache_dir', None)
from_tf = kwargs.pop('from_tf', None)
from_tf = kwargs.pop('from_tf', False)
output_loading_info = kwargs.pop('output_loading_info', False)

# Load config
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
...
...
@@ -239,6 +240,21 @@ class PreTrainedModel(nn.Module):
# Directly load from a TensorFlow checkpoint
return cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'

# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
for key in state_dict.keys():
    new_key = None
    if 'gamma' in key:
        new_key = key.replace('gamma', 'weight')
    if 'beta' in key:
        new_key = key.replace('beta', 'bias')
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)

# Load from a PyTorch state_dict
missing_keys = []
unexpected_keys = []
...
...
@@ -279,6 +295,10 @@ class PreTrainedModel(nn.Module):
if hasattr(model, 'tie_weights'):
    model.tie_weights()  # make sure word embedding weights are still tied

if output_loading_info:
    loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
    return model, loading_info

return model
...
...
pytorch_transformers/tests/modeling_utils_test.py
View file @
162ba383
...
...
@@ -17,21 +17,24 @@ from __future__ import division
from __future__ import print_function

import unittest
import logging

from pytorch_transformers import PretrainedConfig, PreTrainedModel
from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP
class ModelUtilsTest(unittest.TestCase):

    def test_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            config = BertConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, PretrainedConfig)

            model = BertModel.from_pretrained(model_name)
            model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, PreTrainedModel)
            for value in loading_info.values():
                self.assertEqual(len(value), 0)

            config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
            model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment