Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
270fa2f2
Commit
270fa2f2
authored
Dec 11, 2018
by
thomwolf
Browse files
add pretrained loading from state_dict
parent
174cdbcc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
4 deletions
+7
-4
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+7
-4
No files found.
pytorch_pretrained_bert/modeling.py
View file @
270fa2f2
...
@@ -445,9 +445,9 @@ class PreTrainedBertModel(nn.Module):
...
@@ -445,9 +445,9 @@ class PreTrainedBertModel(nn.Module):
module
.
bias
.
data
.
zero_
()
module
.
bias
.
data
.
zero_
()
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name
,
cache_dir
=
None
,
*
inputs
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name
,
state_dict
=
None
,
cache_dir
=
None
,
*
inputs
,
**
kwargs
):
"""
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Instantiate a PreTrainedBertModel from a pre-trained model file
or a pytorch state dict
.
Download and cache the pre-trained model file if needed.
Download and cache the pre-trained model file if needed.
Params:
Params:
...
@@ -461,6 +461,8 @@ class PreTrainedBertModel(nn.Module):
...
@@ -461,6 +461,8 @@ class PreTrainedBertModel(nn.Module):
- a path or url to a pretrained model archive containing:
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
(ex: num_labels for BertForSequenceClassification)
"""
"""
...
@@ -502,8 +504,9 @@ class PreTrainedBertModel(nn.Module):
...
@@ -502,8 +504,9 @@ class PreTrainedBertModel(nn.Module):
logger
.
info
(
"Model config {}"
.
format
(
config
))
logger
.
info
(
"Model config {}"
.
format
(
config
))
# Instantiate model.
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
weights_path
=
os
.
path
.
join
(
serialization_dir
,
WEIGHTS_NAME
)
if
state_dict
is
None
:
state_dict
=
torch
.
load
(
weights_path
)
weights_path
=
os
.
path
.
join
(
serialization_dir
,
WEIGHTS_NAME
)
state_dict
=
torch
.
load
(
weights_path
)
missing_keys
=
[]
missing_keys
=
[]
unexpected_keys
=
[]
unexpected_keys
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment