chenpangpang / transformers · Commit b8009cb0

Make PreTrainedModel.from_pretrained pass unused arguments to model

Authored Jul 22, 2019 by Anish Moorthy
Parent 2f869dc6
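In short: keyword arguments passed to PreTrainedModel.from_pretrained that do not name a field of the loaded configuration are now forwarded to the model's constructor instead of being silently ignored. A minimal usage sketch of the behaviour this enables (the subclass MyBertForTagging and its num_tags parameter are hypothetical, invented here for illustration):

from pytorch_transformers import BertModel

class MyBertForTagging(BertModel):
    # Hypothetical subclass whose constructor takes an extra keyword argument.
    def __init__(self, config, num_tags=2):
        super(MyBertForTagging, self).__init__(config)
        self.num_tags = num_tags

# 'num_tags' is not an attribute of BertConfig, so after this commit it is
# returned as an unused argument and forwarded to MyBertForTagging.__init__
# rather than being dropped.
model = MyBertForTagging.from_pretrained('bert-base-uncased', num_tags=5)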
Showing 1 changed file with 26 additions and 9 deletions.

pytorch_transformers/modeling_utils.py (+26, -9)
@@ -78,7 +78,7 @@ class PretrainedConfig(object):
         self.to_json_file(output_config_file)

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         r""" Instantiate a PretrainedConfig from a pre-trained model configuration.

         Params:
@@ -105,6 +105,7 @@ class PretrainedConfig(object):
         """
         cache_dir = kwargs.pop('cache_dir', None)
+        return_unused_args = kwargs.pop('return_unused_args', False)

         if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
             config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
@@ -148,6 +149,9 @@ class PretrainedConfig(object):
             kwargs.pop(key, None)

         logger.info("Model config %s", config)
-        return config
+        if return_unused_args:
+            return config, kwargs
+        else:
+            return config

     @classmethod
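Read together with the context above (the loop that pops each matching key), the config-side contract is: any kwarg naming an existing config attribute is applied as an override and consumed; with return_unused_args=True the leftovers come back to the caller instead of being discarded. A standalone sketch of that split, simplified (the real method also resolves, downloads, and caches the config file):

def split_config_kwargs(config, **kwargs):
    # Apply kwargs that match existing config attributes as overrides...
    unused = {}
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)
        else:
            # ...and collect the rest to hand to the model constructor later.
            unused[key] = value
    return config, unused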
@@ -305,7 +309,7 @@ class PreTrainedModel(nn.Module):
         torch.save(model_to_save.state_dict(), output_model_file)

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.

         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are desactivated)
@@ -336,9 +340,17 @@ class PreTrainedModel(nn.Module):
                 configuration should be cached if the standard cache should not be used.

             **output_loading_info**: (`optional`) boolean:
                 Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
+            **model_args**: (`optional`) Sequence:
+                All positional arguments will be passed to the underlying model's __init__ function
+
             **kwargs**: (`optional`) dict:
-                Dictionnary of key, values to update the configuration object after loading.
-                Can be used to override selected configuration parameters. E.g. ``output_attention=True``
+                Dictionary of key, values to update the configuration object after loading.
+                Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
+                If the `config` key is present, then **kwargs will be passed to the model.
+                If said key is *not* present, then kwargs will be used to
+                override any keys shared with the default configuration for the
+                given pretrained_model_name_or_path, and only the unshared
+                key/value pairs will be passed to the model.

         Examples::
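Per the updated docstring, the split only happens when no explicit config is supplied; if a config object is passed in, every remaining kwarg goes to the model untouched. A hedged illustration, reusing the hypothetical MyBertForTagging from the sketch above (output_attentions is assumed here to be an attribute of the loaded BertConfig):

from pytorch_transformers import BertConfig

# No config given: 'output_attentions' matches a config attribute and becomes
# an override; 'num_tags' does not, so it reaches MyBertForTagging.__init__.
model = MyBertForTagging.from_pretrained(
    'bert-base-uncased', output_attentions=True, num_tags=5)

# Explicit config given: no splitting happens, so every kwarg is forwarded to
# the model constructor; only pass constructor arguments in this case.
config = BertConfig.from_pretrained('bert-base-uncased')
model = MyBertForTagging.from_pretrained(
    'bert-base-uncased', config=config, num_tags=5)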
@@ -359,7 +371,12 @@ class PreTrainedModel(nn.Module):
         # Load config
         if config is None:
-            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+            config, model_kwargs = cls.config_class.from_pretrained(
+                pretrained_model_name_or_path, *model_args,
+                return_unused_args=True, **kwargs
+            )
+        else:
+            model_kwargs = kwargs

         # Load model
         if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
@@ -400,7 +417,7 @@ class PreTrainedModel(nn.Module):
                 archive_file, resolved_archive_file))

         # Instantiate model.
-        model = cls(config)
+        model = cls(config, *model_args, **model_kwargs)

         if state_dict is None and not from_tf:
             state_dict = torch.load(resolved_archive_file, map_location='cpu')
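This last hunk is where the two halves meet: cls(config, *model_args, **model_kwargs) applies whatever survived the config split at construction time. Positional arguments appear to travel the same route via *model_args (the config loader accepts and ignores them through its *args), so a subclass constructor with an extra positional parameter should also work. A sketch, again with a hypothetical subclass:

from pytorch_transformers import BertModel

class PooledBert(BertModel):
    # Hypothetical subclass with an extra positional constructor argument.
    def __init__(self, config, pool_mode):
        super(PooledBert, self).__init__(config)
        self.pool_mode = pool_mode

# 'mean' is ignored by the config loader and then passed through *model_args
# into PooledBert.__init__.
model = PooledBert.from_pretrained('bert-base-uncased', 'mean')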