Unverified Commit 368670ac authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #866 from xanlsh/master

Rework how PreTrainedModel.from_pretrained handles its arguments
parents 6070b554 4fb56c77
......@@ -78,7 +78,7 @@ class PretrainedConfig(object):
self.to_json_file(output_config_file)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r""" Instantiate a PretrainedConfig from a pre-trained model configuration.
Params:
......@@ -91,20 +91,33 @@ class PretrainedConfig(object):
**cache_dir**: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
**return_unused_kwargs**: (`optional`) bool:
- If False, then this function returns just the final configuration object.
- If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs`
is a dictionary consisting of the key/value pairs whose keys are not configuration attributes:
ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
**kwargs**: (`optional`) dict:
Dictionnary of key, values to update the configuration object after loading.
Can be used to override selected configuration parameters.
Dictionary of key/value pairs with which to update the configuration object after loading.
- The values in kwargs of any keys which are configuration attributes will be used
to override the loaded values.
- Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
by the `return_unused_kwargs` keyword parameter.
Examples::
>>> config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
>>> config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
>>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
>>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True)
>>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
>>> assert config.output_attention == True
>>> config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
>>> foo=False, return_unused_kwargs=True)
>>> assert config.output_attention == True
>>> assert unused_kwargs == {'foo': False}
"""
cache_dir = kwargs.pop('cache_dir', None)
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
......@@ -148,6 +161,9 @@ class PretrainedConfig(object):
kwargs.pop(key, None)
logger.info("Model config %s", config)
if return_unused_kwargs:
return config, kwargs
else:
return config
@classmethod
......@@ -305,7 +321,7 @@ class PreTrainedModel(nn.Module):
torch.save(model_to_save.state_dict(), output_model_file)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are desactivated)
......@@ -322,6 +338,8 @@ class PreTrainedModel(nn.Module):
provided as `config` argument. This loading option is slower than converting the TensorFlow
checkpoint in a PyTorch model using the provided conversion scripts and loading
the PyTorch model afterwards.
**model_args**: (`optional`) Sequence:
All remaning positional arguments will be passed to the underlying model's __init__ function
**config**: an optional configuration for the model to use instead of an automatically loaded configuation.
Configuration can be automatically loaded when:
- the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
......@@ -337,8 +355,17 @@ class PreTrainedModel(nn.Module):
**output_loading_info**: (`optional`) boolean:
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
**kwargs**: (`optional`) dict:
Dictionnary of key, values to update the configuration object after loading.
Can be used to override selected configuration parameters. E.g. ``output_attention=True``
Dictionary of key, values to update the configuration object after loading.
Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
- If a configuration is provided with `config`, **kwargs will be directly passed
to the underlying model's __init__ method.
- If a configuration is not provided, **kwargs will be first passed to the pretrained
model configuration class loading function (`PretrainedConfig.from_pretrained`).
Each key of **kwargs that corresponds to a configuration attribute
will be used to override said attribute with the supplied **kwargs value.
Remaining keys that do not correspond to any configuration attribute will
be passed to the underlying model's __init__ function.
Examples::
......@@ -359,7 +386,13 @@ class PreTrainedModel(nn.Module):
# Load config
if config is None:
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
config, model_kwargs = cls.config_class.from_pretrained(
pretrained_model_name_or_path, *model_args,
cache_dir=cache_dir, return_unused_kwargs=True,
**kwargs
)
else:
model_kwargs = kwargs
# Load model
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
......@@ -400,7 +433,7 @@ class PreTrainedModel(nn.Module):
archive_file, resolved_archive_file))
# Instantiate model.
model = cls(config)
model = cls(config, *model_args, **model_kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment