Unverified commit 368670ac, authored by Thomas Wolf and committed by GitHub

Merge pull request #866 from xanlsh/master

Rework how PreTrainedModel.from_pretrained handles its arguments
parents 6070b554 4fb56c77
@@ -78,7 +78,7 @@ class PretrainedConfig(object):
         self.to_json_file(output_config_file)

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         r""" Instantiate a PretrainedConfig from a pre-trained model configuration.

         Params:
@@ -91,20 +91,33 @@ class PretrainedConfig(object):
             **cache_dir**: (`optional`) string:
                 Path to a directory in which a downloaded pre-trained model
                 configuration should be cached if the standard cache should not be used.
+            **return_unused_kwargs**: (`optional`) bool:
+                - If False, then this function returns just the final configuration object.
+                - If True, then this function returns a tuple `(config, unused_kwargs)` where `unused_kwargs`
+                  is a dictionary consisting of the key/value pairs whose keys are not configuration attributes,
+                  i.e. the part of kwargs which has not been used to update `config` and is otherwise ignored.
             **kwargs**: (`optional`) dict:
-                Dictionnary of key, values to update the configuration object after loading.
-                Can be used to override selected configuration parameters.
+                Dictionary of key/value pairs with which to update the configuration object after loading.
+                - The values in kwargs for any keys which are configuration attributes will be used
+                  to override the loaded values.
+                - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
+                  by the `return_unused_kwargs` keyword parameter.
         Examples::

             >>> config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
             >>> config = BertConfig.from_pretrained('./test/saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
             >>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
-            >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True)
+            >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
             >>> assert config.output_attention == True
+            >>> config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
+            >>>                                                    foo=False, return_unused_kwargs=True)
+            >>> assert config.output_attention == True
+            >>> assert unused_kwargs == {'foo': False}

         """
         cache_dir = kwargs.pop('cache_dir', None)
+        return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)

         if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
             config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
@@ -148,7 +161,10 @@ class PretrainedConfig(object):
             kwargs.pop(key, None)

         logger.info("Model config %s", config)
-        return config
+        if return_unused_kwargs:
+            return config, kwargs
+        else:
+            return config

     @classmethod
     def from_dict(cls, json_object):
@@ -305,7 +321,7 @@ class PreTrainedModel(nn.Module):
         torch.save(model_to_save.state_dict(), output_model_file)

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.

         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
@@ -322,6 +338,8 @@ class PreTrainedModel(nn.Module):
                 provided as `config` argument. This loading option is slower than converting the TensorFlow
                 checkpoint into a PyTorch model using the provided conversion scripts and loading
                 the PyTorch model afterwards.
+            **model_args**: (`optional`) Sequence:
+                All remaining positional arguments will be passed to the underlying model's __init__ function
             **config**: an optional configuration for the model to use instead of an automatically loaded configuration.
                 Configuration can be automatically loaded when:
                 - the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
@@ -337,8 +355,17 @@ class PreTrainedModel(nn.Module):
             **output_loading_info**: (`optional`) boolean:
                 Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
             **kwargs**: (`optional`) dict:
-                Dictionnary of key, values to update the configuration object after loading.
-                Can be used to override selected configuration parameters. E.g. ``output_attention=True``
+                Dictionary of key/value pairs with which to update the configuration object after loading.
+                Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
+
+                - If a configuration is provided with `config`, **kwargs will be directly passed
+                  to the underlying model's __init__ method.
+                - If a configuration is not provided, **kwargs will first be passed to the pretrained
+                  model configuration class loading function (`PretrainedConfig.from_pretrained`).
+                  Each key of **kwargs that corresponds to a configuration attribute
+                  will be used to override said attribute with the supplied **kwargs value.
+                  Remaining keys that do not correspond to any configuration attribute will
+                  be passed to the underlying model's __init__ function.

         Examples::
@@ -359,7 +386,13 @@ class PreTrainedModel(nn.Module):
         # Load config
         if config is None:
-            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+            config, model_kwargs = cls.config_class.from_pretrained(
+                pretrained_model_name_or_path, *model_args,
+                cache_dir=cache_dir, return_unused_kwargs=True,
+                **kwargs
+            )
+        else:
+            model_kwargs = kwargs

         # Load model
         if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
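In short, the hunk above creates the two keyword-argument paths that the docstring describes. A hedged doctest-style illustration, reusing the `output_attention` name from the docstring above (treat it as a placeholder attribute, not a guaranteed API surface)::

    >>> # No explicit config: `output_attention` is a config attribute, so it
    >>> # updates the loaded config; keys the config does not know would be
    >>> # collected into model_kwargs and handed to the model's __init__.
    >>> model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)
    >>> assert model.config.output_attention == True
    >>> # Explicit config: kwargs skip the config update entirely and go to __init__.
    >>> config = BertConfig.from_pretrained('bert-base-uncased')
    >>> model = BertModel.from_pretrained('bert-base-uncased', config=config)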
@@ -400,7 +433,7 @@ class PreTrainedModel(nn.Module):
                 archive_file, resolved_archive_file))

         # Instantiate model.
-        model = cls(config)
+        model = cls(config, *model_args, **model_kwargs)

         if state_dict is None and not from_tf:
             state_dict = torch.load(resolved_archive_file, map_location='cpu')
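This is where `model_args` and the leftover `model_kwargs` finally land. A minimal sketch with a hypothetical subclass (`MyModel`, `extra_dim`, and `use_bias` are invented for illustration; a real subclass also needs the usual class attributes such as `config_class`)::

    class MyModel(PreTrainedModel):
        def __init__(self, config, extra_dim, use_bias=True):
            super(MyModel, self).__init__(config)
            self.extra_dim = extra_dim  # arrives via *model_args
            self.use_bias = use_bias    # arrives via **model_kwargs

    # Load the config explicitly so the positional 128 is handed straight to
    # __init__; equivalent to MyModel(config, 128, use_bias=False) plus weights.
    config = MyModel.config_class.from_pretrained('./my_checkpoint/')
    model = MyModel.from_pretrained('./my_checkpoint/', 128, config=config, use_bias=False)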
@@ -530,7 +563,7 @@ class PoolerEndLogits(nn.Module):
             **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
                 hidden states of the first tokens for the labeled span.
             **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                 position of the first token for the labeled span.
             **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
                 Mask of invalid positions such as query and special symbols (PAD, SEP, CLS)
                 1.0 means token should be masked.
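A shape-only sketch of these inputs, following the docstring above (the sizes are dummies, and the commented forward call assumes the signature implied by the parameter list)::

    import torch

    batch_size, seq_len, hidden_size = 4, 128, 768
    hidden_states = torch.rand(batch_size, seq_len, hidden_size)
    start_positions = torch.zeros(batch_size, dtype=torch.long)  # index of each span's first token
    p_mask = torch.zeros(batch_size, seq_len)  # set 1.0 at PAD/SEP/CLS positions to mask them
    # end_logits = pooler_end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)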
@@ -717,7 +750,7 @@ class SequenceSummary(nn.Module):
                 - 'attn' => Not implemented now, use multi-head attention
             summary_use_proj: Add a projection after the vector extraction
             summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
             summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
             summary_first_dropout: Add a dropout before the projection and activation
             summary_last_dropout: Add a dropout after the projection and activation
         """