Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b8009cb0
Commit
b8009cb0
authored
Jul 22, 2019
by
Anish Moorthy
Browse files
Make PreTrainedModel.from_pretrained pass unused arguments to model
parent
2f869dc6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
9 deletions
+26
-9
pytorch_transformers/modeling_utils.py
pytorch_transformers/modeling_utils.py
+26
-9
No files found.
pytorch_transformers/modeling_utils.py
View file @
b8009cb0
...
@@ -78,7 +78,7 @@ class PretrainedConfig(object):
...
@@ -78,7 +78,7 @@ class PretrainedConfig(object):
self
.
to_json_file
(
output_config_file
)
self
.
to_json_file
(
output_config_file
)
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
input
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
args
,
**
kwargs
):
r
""" Instantiate a PretrainedConfig from a pre-trained model configuration.
r
""" Instantiate a PretrainedConfig from a pre-trained model configuration.
Params:
Params:
...
@@ -105,6 +105,7 @@ class PretrainedConfig(object):
...
@@ -105,6 +105,7 @@ class PretrainedConfig(object):
"""
"""
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
return_unused_args
=
kwargs
.
pop
(
'return_unused_args'
,
False
)
if
pretrained_model_name_or_path
in
cls
.
pretrained_config_archive_map
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_config_archive_map
:
config_file
=
cls
.
pretrained_config_archive_map
[
pretrained_model_name_or_path
]
config_file
=
cls
.
pretrained_config_archive_map
[
pretrained_model_name_or_path
]
...
@@ -148,6 +149,9 @@ class PretrainedConfig(object):
...
@@ -148,6 +149,9 @@ class PretrainedConfig(object):
kwargs
.
pop
(
key
,
None
)
kwargs
.
pop
(
key
,
None
)
logger
.
info
(
"Model config %s"
,
config
)
logger
.
info
(
"Model config %s"
,
config
)
if
return_unused_args
:
return
config
,
kwargs
else
:
return
config
return
config
@
classmethod
@
classmethod
...
@@ -305,7 +309,7 @@ class PreTrainedModel(nn.Module):
...
@@ -305,7 +309,7 @@ class PreTrainedModel(nn.Module):
torch
.
save
(
model_to_save
.
state_dict
(),
output_model_file
)
torch
.
save
(
model_to_save
.
state_dict
(),
output_model_file
)
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
input
s
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_arg
s
,
**
kwargs
):
r
"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
r
"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are desactivated)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are desactivated)
...
@@ -336,9 +340,17 @@ class PreTrainedModel(nn.Module):
...
@@ -336,9 +340,17 @@ class PreTrainedModel(nn.Module):
configuration should be cached if the standard cache should not be used.
configuration should be cached if the standard cache should not be used.
**output_loading_info**: (`optional`) boolean:
**output_loading_info**: (`optional`) boolean:
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
**model_args**: (`optional`) Sequence:
All positional arguments will be passed to the underlying model's __init__ function
**kwargs**: (`optional`) dict:
**kwargs**: (`optional`) dict:
Dictionnary of key, values to update the configuration object after loading.
Dictionary of key, values to update the configuration object after loading.
Can be used to override selected configuration parameters. E.g. ``output_attention=True``
Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
If config is None, then **kwargs will be passed to the model.
If said key is *not* present, then kwargs will be used to
override any keys shared with the default configuration for the
given pretrained_model_name_or_path, and only the unshared
key/value pairs will be passed to the model.
Examples::
Examples::
...
@@ -359,7 +371,12 @@ class PreTrainedModel(nn.Module):
...
@@ -359,7 +371,12 @@ class PreTrainedModel(nn.Module):
# Load config
# Load config
if
config
is
None
:
if
config
is
None
:
config
=
cls
.
config_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
config
,
model_kwargs
=
cls
.
config_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
return_unused_args
=
True
,
**
kwargs
)
else
:
model_kwargs
=
kwargs
# Load model
# Load model
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
...
@@ -400,7 +417,7 @@ class PreTrainedModel(nn.Module):
...
@@ -400,7 +417,7 @@ class PreTrainedModel(nn.Module):
archive_file
,
resolved_archive_file
))
archive_file
,
resolved_archive_file
))
# Instantiate model.
# Instantiate model.
model
=
cls
(
config
)
model
=
cls
(
config
,
*
model_args
,
**
model_kwargs
)
if
state_dict
is
None
and
not
from_tf
:
if
state_dict
is
None
and
not
from_tf
:
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment