chenpangpang/transformers · Commits

Commit e179c554, authored Jul 23, 2019 by Anish Moorthy

    Add docs for from_pretrained functions, rename return_unused_args

Parent: 490ebbdc
Showing 1 changed file with 28 additions and 13 deletions.

pytorch_transformers/modeling_utils.py  (+28 -13)
@@ -91,21 +91,33 @@ class PretrainedConfig(object):
             **cache_dir**: (`optional`) string:
                 Path to a directory in which a downloaded pre-trained model
                 configuration should be cached if the standard cache should not be used.
+            **return_unused_kwargs**: (`optional`) bool:
+                - If False, then this function returns just the final configuration object.
+                - If True, then this function returns a tuple `(config, unused_kwargs)` where `unused_kwargs`
+                  is a dictionary consisting of the key/value pairs whose keys are not configuration attributes:
+                  i.e. the part of kwargs which has not been used to update `config` and is otherwise ignored.
             **kwargs**: (`optional`) dict:
-                Dictionnary of key, values to update the configuration object after loading.
-                Can be used to override selected configuration parameters.
+                Dictionary of key/value pairs with which to update the configuration object after loading.
+                - The values in kwargs of any keys which are configuration attributes will be used
+                  to override the loaded values.
+                - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
+                  by the `return_unused_kwargs` keyword parameter.

         Examples::

             >>> config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
             >>> config = BertConfig.from_pretrained('./test/saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
             >>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
-            >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True)
+            >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
             >>> assert config.output_attention == True
+            >>> config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
+            >>>                                                    foo=False, return_unused_kwargs=True)
+            >>> assert config.output_attention == True
+            >>> assert unused_kwargs == {'foo': False}
         """
         cache_dir = kwargs.pop('cache_dir', None)
-        return_unused_args = kwargs.pop('return_unused_args', False)
+        return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)

         if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
             config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
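As a quick sanity check on the contract this hunk documents, here is a minimal standalone sketch of the `return_unused_kwargs` behavior; `Config` is a stand-in stub, not the library's `PretrainedConfig`:

    class Config(object):
        """Stand-in stub illustrating the documented kwargs handling."""
        output_attention = False

        @classmethod
        def from_pretrained(cls, name, **kwargs):
            return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
            config = cls()
            for key in list(kwargs):
                if hasattr(config, key):   # configuration attribute: override it
                    setattr(config, key, kwargs.pop(key))
            if return_unused_kwargs:       # hand back whatever was not consumed
                return config, kwargs
            return config

    config, unused = Config.from_pretrained('bert-base-uncased',
                                            output_attention=True, foo=False,
                                            return_unused_kwargs=True)
    assert config.output_attention is True
    assert unused == {'foo': False}

With `return_unused_kwargs=False` (the default), the same call would return just `config`, and `foo` would be ignored, as the docstring states.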
@@ -149,7 +161,7 @@ class PretrainedConfig(object):
             kwargs.pop(key, None)

         logger.info("Model config %s", config)
-        if return_unused_args:
+        if return_unused_kwargs:
             return config, kwargs
         else:
             return config
@@ -326,6 +338,8 @@ class PreTrainedModel(nn.Module):
                 provided as `config` argument. This loading option is slower than converting the TensorFlow
                 checkpoint in a PyTorch model using the provided conversion scripts and loading
                 the PyTorch model afterwards.
+            **model_args**: (`optional`) Sequence:
+                All remaining positional arguments will be passed to the underlying model's __init__ function
             **config**: an optional configuration for the model to use instead of an automatically loaded configuration.
                 Configuration can be automatically loaded when:
                 - the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
@@ -340,17 +354,18 @@ class PreTrainedModel(nn.Module):
                 configuration should be cached if the standard cache should not be used.
             **output_loading_info**: (`optional`) boolean:
                 Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
-            **model_args**: (`optional`) Sequence:
-                All positional arguments will be passed to the underlying model's __init__ function
             **kwargs**: (`optional`) dict:
-                Dictionary of key, values to update the configuration object after loading.
-                Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
-                If config is None, then **kwargs will be passed to the model.
-                If said key is *not* present, then kwargs will be used to
-                    override any keys shared with the default configuration for the
-                    given pretrained_model_name_or_path, and only the unshared
-                    key/value pairs will be passed to the model.
+                - If a configuration is provided with `config`, **kwargs will be directly passed
+                  to the underlying model's __init__ method.
+                - If a configuration is not provided, **kwargs will be first passed to the pretrained
+                  model configuration class loading function (`PretrainedConfig.from_pretrained`).
+                  Each key of **kwargs that corresponds to a configuration attribute
+                  will be used to override said attribute with the supplied **kwargs value.
+                  Remaining keys that do not correspond to any configuration attribute will
+                  be passed to the underlying model's __init__ function.

         Examples::
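The two bullet points added above describe two distinct code paths. They are sketched below with stand-in `Config`/`Model` stubs (illustrative only; the real classes live in `pytorch_transformers`):

    class Config(object):
        """Stub standing in for PretrainedConfig."""
        output_attention = False

        @classmethod
        def from_pretrained(cls, name, **kwargs):
            return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
            config = cls()
            for key in list(kwargs):
                if hasattr(config, key):
                    setattr(config, key, kwargs.pop(key))
            return (config, kwargs) if return_unused_kwargs else config

    class Model(object):
        """Stub standing in for PreTrainedModel subclasses."""
        config_class = Config

        def __init__(self, config, *model_args, **model_kwargs):
            self.config = config
            self.model_args = model_args
            self.model_kwargs = model_kwargs

        @classmethod
        def from_pretrained(cls, name, *model_args, **kwargs):
            config = kwargs.pop('config', None)
            if config is None:
                # No explicit config: kwargs first update the auto-loaded config;
                # only the unused remainder reaches the model's __init__.
                config, kwargs = cls.config_class.from_pretrained(
                    name, return_unused_kwargs=True, **kwargs)
            # With an explicit config, every kwarg goes straight to the model.
            return cls(config, *model_args, **kwargs)

    m = Model.from_pretrained('bert-base-uncased', output_attention=True, foo=1)
    assert m.config.output_attention is True
    assert m.model_kwargs == {'foo': 1}

This mirrors the `return_unused_kwargs=True` call site updated in the hunk that follows.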
@@ -373,7 +388,7 @@ class PreTrainedModel(nn.Module):
         if config is None:
             config, model_kwargs = cls.config_class.from_pretrained(pretrained_model_name_or_path, *model_args,
-                cache_dir=cache_dir, return_unused_args=True,
+                cache_dir=cache_dir, return_unused_kwargs=True,
                 **kwargs)
         else: