"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c6e18de9f86bc543937cfc558e3d28fedf5a08cb"
Commit 4c81960b authored by Rémi Louf

comment the seq2seq functions

parent 6d6c3267
@@ -43,13 +43,21 @@ class PreTrainedSeq2seq(nn.Module):
@classmethod
def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the base model classes of the library r""" Instantiates an encoder and a decoder from one or two base classes
from a pre-trained model configuration. of the library from pre-trained model checkpoints.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you need to first set it back in training mode with `model.train()`
Params: Params:
pretrained_model_name_or_path: either: encoder_pretrained_model_name_or_path: information necessary to initiate the encoder. Either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
decoder_pretrained_model_name_or_path: information necessary to initialize the decoder. Either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
@@ -84,21 +92,17 @@ class PreTrainedSeq2seq(nn.Module):
output_loading_info: (`optional`) boolean:
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
kwargs: (`optional`) Remaining dictionary of keyword arguments.
Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g. ``output_attention=True``). These behave differently depending on whether a `config` is provided or automatically loaded:
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
- If a configuration is not provided, ``kwargs`` will first be passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
You can specify different kwargs for the decoder by prefixing the key with `decoder_` (e.g. ``decoder_output_attention=True``).
Examples::

model = PreTrainedSeq2seq.from_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert

"""
# Separate the encoder- and decoder- specific kwargs. A kwarg is
@@ -115,35 +119,49 @@ class PreTrainedSeq2seq(nn.Module):
encoder = kwargs.pop('encoder_model', None)
if encoder is None:
kwargs_encoder['is_decoder'] = False
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

decoder = kwargs.pop('decoder_model', None)
if decoder is None:
kwargs_decoder['is_decoder'] = True
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

model = cls(encoder, decoder)
return model
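A minimal usage sketch of the loading path above, assuming `PreTrainedSeq2seq` is importable from the package root and reusing the kwarg name from the docstring's own example; treat it as a sketch, not the canonical API:

from transformers import PreTrainedSeq2seq

# Bert2Bert: a BERT encoder plus a BERT decoder with an LM head.
# Unprefixed kwargs are treated as encoder kwargs; the `decoder_` prefix
# is stripped and the kwarg is routed to the decoder instead.
model = PreTrainedSeq2seq.from_pretrained(
    'bert-base-uncased',            # encoder checkpoint
    'bert-base-uncased',            # decoder checkpoint
    output_attention=True,          # encoder configuration update (docstring's example kwarg)
    decoder_output_attention=True,  # same update, applied to the decoder
)
model.eval()  # already the default after loading; call model.train() before fine-tuning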
def forward(self, *inputs, **kwargs):
""" The forward pass on a seq2seq model depends on what we are performing:

- During training we perform one forward pass through both the encoder
and the decoder;
- During prediction, we perform one forward pass through the encoder,
and then perform several forward passes with the encoder's hidden
state through the decoder to decode a full sequence.

Therefore, we skip the forward pass on the encoder if an argument named
`encoder_hidden_states` is passed to this function.
"""
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific if the key starts with `decoder_`
kwargs_decoder = {}
kwargs_encoder = kwargs
for key in list(kwargs_encoder.keys()):  # copy the keys: we pop from the dict while iterating
if key.startswith('decoder_'):
kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)

# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None)
if encoder_hidden_states is None:
encoder_outputs = self.encoder(*inputs, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
else:
encoder_outputs = ()

# Decode
kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states
decoder_outputs = self.decoder(**kwargs_decoder)

return decoder_outputs + encoder_outputs
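A sketch of the two call patterns the docstring describes. The model construction, tensor shapes and the greedy loop are illustrative assumptions; only the `decoder_` prefix and the `encoder_hidden_states` shortcut come from the code above:

import torch
from transformers import PreTrainedSeq2seq  # import path assumed, as above

model = PreTrainedSeq2seq.from_pretrained('bert-base-uncased', 'bert-base-uncased')

encoder_input_ids = torch.randint(0, 30522, (1, 16))  # dummy source token ids
decoder_input_ids = torch.randint(0, 30522, (1, 4))   # dummy target token ids

# Training: one pass through both models. Positional inputs go to the
# encoder; `decoder_input_ids` becomes the decoder's `input_ids`.
outputs = model(encoder_input_ids, decoder_input_ids=decoder_input_ids)

# Prediction: encode once, then reuse the hidden states at every step so
# the encoder forward pass is skipped.
encoder_hidden_states = model.encoder(encoder_input_ids)[0]
for _ in range(10):  # illustrative decoding budget
    outputs = model(decoder_input_ids=decoder_input_ids,
                    encoder_hidden_states=encoder_hidden_states)
    next_token = outputs[0][:, -1:].argmax(dim=-1)  # greedy pick from the LM logits
    decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)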
@@ -161,11 +179,10 @@ class Model2LSTM(PreTrainedSeq2seq):
# We will create a randomly initialized LSTM model as decoder
if 'decoder_config' not in kwargs:
raise ValueError("To load an LSTM in Seq2seq model, please supply either: "
" - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
" - a dictionary of configuration parameters that will be used to initialize a"
" torch.nn.LSTM model as `decoder_config` keyword argument. "
" E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`")
kwargs['decoder_model'] = torch.nn.LSTM(**kwargs.pop('decoder_config'))
model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs)
return model
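A hypothetical sketch of the two ways the LSTM decoder can be supplied, reusing the configuration keys from the error message above; passing `None` for the decoder checkpoint slot is an assumption (the LSTM is built locally, so that argument goes unused):

import torch
from transformers import Model2LSTM  # import path assumed

# 1) Let from_pretrained build the LSTM from a configuration dictionary.
model = Model2LSTM.from_pretrained(
    'bert-base-uncased',  # encoder checkpoint
    None,                 # decoder checkpoint slot, unused here
    decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2},
)

# 2) Or pass a ready-made torch.nn.LSTM as the decoder.
lstm = torch.nn.LSTM(input_size=768, hidden_size=768, num_layers=2)
model = Model2LSTM.from_pretrained('bert-base-uncased', None, decoder_model=lstm)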