Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
162ba383
Commit
162ba383
authored
Jul 05, 2019
by
thomwolf
Browse files
fix model loading
parent
6dacc79d
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
4 deletions
+28
-4
examples/run_bert_classifier.py
examples/run_bert_classifier.py
+2
-1
pytorch_transformers/modeling_utils.py
pytorch_transformers/modeling_utils.py
+21
-1
pytorch_transformers/tests/modeling_utils_test.py
pytorch_transformers/tests/modeling_utils_test.py
+5
-2
No files found.
examples/run_bert_classifier.py
View file @
162ba383
...
...
@@ -308,7 +308,8 @@ def main():
input_ids
,
input_mask
,
segment_ids
,
label_ids
=
batch
# define a new function to compute loss values for both output_modes
logits
=
model
(
input_ids
,
token_type_ids
=
segment_ids
,
attention_mask
=
input_mask
)
ouputs
=
model
(
input_ids
,
token_type_ids
=
segment_ids
,
attention_mask
=
input_mask
,
labels
=
label_ids
)
loss
=
if
output_mode
==
"classification"
:
loss_fct
=
CrossEntropyLoss
()
...
...
pytorch_transformers/modeling_utils.py
View file @
162ba383
...
...
@@ -193,7 +193,8 @@ class PreTrainedModel(nn.Module):
"""
state_dict
=
kwargs
.
pop
(
'state_dict'
,
None
)
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
from_tf
=
kwargs
.
pop
(
'from_tf'
,
None
)
from_tf
=
kwargs
.
pop
(
'from_tf'
,
False
)
output_loading_info
=
kwargs
.
pop
(
'output_loading_info'
,
False
)
# Load config
config
=
cls
.
config_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
...
...
@@ -239,6 +240,21 @@ class PreTrainedModel(nn.Module):
# Directly load from a TensorFlow checkpoint
return
cls
.
load_tf_weights
(
model
,
config
,
resolved_archive_file
[:
-
6
])
# Remove the '.index'
# Convert old format to new format if needed from a PyTorch state_dict
old_keys
=
[]
new_keys
=
[]
for
key
in
state_dict
.
keys
():
new_key
=
None
if
'gamma'
in
key
:
new_key
=
key
.
replace
(
'gamma'
,
'weight'
)
if
'beta'
in
key
:
new_key
=
key
.
replace
(
'beta'
,
'bias'
)
if
new_key
:
old_keys
.
append
(
key
)
new_keys
.
append
(
new_key
)
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
state_dict
[
new_key
]
=
state_dict
.
pop
(
old_key
)
# Load from a PyTorch state_dict
missing_keys
=
[]
unexpected_keys
=
[]
...
...
@@ -279,6 +295,10 @@ class PreTrainedModel(nn.Module):
if
hasattr
(
model
,
'tie_weights'
):
model
.
tie_weights
()
# make sure word embedding weights are still tied
if
output_loading_info
:
loading_info
=
{
"missing_keys"
:
missing_keys
,
"unexpected_keys"
:
unexpected_keys
,
"error_msgs"
:
error_msgs
}
return
model
,
loading_info
return
model
...
...
pytorch_transformers/tests/modeling_utils_test.py
View file @
162ba383
...
...
@@ -17,21 +17,24 @@ from __future__ import division
from
__future__
import
print_function
import
unittest
import
logging
from
pytorch_transformers
import
PretrainedConfig
,
PreTrainedModel
from
pytorch_transformers.modeling_bert
import
BertModel
,
BertConfig
,
PRETRAINED_MODEL_ARCHIVE_MAP
class
ModelUtilsTest
(
unittest
.
TestCase
):
def
test_model_from_pretrained
(
self
):
logging
.
basicConfig
(
level
=
logging
.
INFO
)
for
model_name
in
list
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
config
=
BertConfig
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
config
)
self
.
assertIsInstance
(
config
,
PretrainedConfig
)
model
=
BertModel
.
from_pretrained
(
model_name
)
model
,
loading_info
=
BertModel
.
from_pretrained
(
model_name
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsInstance
(
model
,
PreTrainedModel
)
for
value
in
loading_info
.
values
():
self
.
assertEqual
(
len
(
value
),
0
)
config
=
BertConfig
.
from_pretrained
(
model_name
,
output_attentions
=
True
,
output_hidden_states
=
True
)
model
=
BertModel
.
from_pretrained
(
model_name
,
output_attentions
=
True
,
output_hidden_states
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment