"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "bb2e088be70071dc543d3e118375352d5afcfe73"
Commit 95ec1d08 authored by Rémi Louf

separate inputs into encoder & decoder inputs

parent e4e0ee14
@@ -130,7 +130,7 @@ class PreTrainedSeq2seq(nn.Module):
         return model
 
-    def forward(self, *inputs, **kwargs):
+    def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
         """ The forward pass on a seq2seq depends on what we are performing:
         - During training we perform one forward pass through both the encoder
@@ -142,6 +142,11 @@ class PreTrainedSeq2seq(nn.Module):
         Therefore, we skip the forward pass on the encoder if an argument named
         `encoder_hidden_states` is passed to this function.
+
+        Params:
+            encoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
+                Indices of encoder input sequence tokens in the vocabulary.
+            decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
+                Indices of decoder input sequence tokens in the vocabulary.
         """
         # Separate the encoder- and decoder-specific kwargs. A kwarg is
         # decoder-specific if the key starts with `decoder_`
@@ -154,14 +159,14 @@ class PreTrainedSeq2seq(nn.Module):
         # Encode if needed (training, first prediction pass)
         encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None)
         if encoder_hidden_states is None:
-            encoder_outputs = self.encoder(*inputs, **kwargs_encoder)
+            encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
             encoder_hidden_states = encoder_outputs[0]
         else:
             encoder_outputs = ()
 
         # Decode
         kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states
-        decoder_outputs = self.decoder(**kwargs_decoder)
+        decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
         return decoder_outputs + encoder_outputs
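For reference, below is a minimal, runnable sketch of the calling convention this commit introduces. ToyEncoder, ToyDecoder, and Seq2seqSketch are hypothetical stand-ins that return random tensors rather than real pretrained modules, and the `decoder_`-prefix kwarg split is reconstructed from the comment in the diff (its actual implementation falls in the elided lines above); only the forward logic mirrors the change.

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a pretrained encoder."""
    def forward(self, input_ids, **kwargs):
        # Return a tuple whose first element is the hidden states,
        # matching the `encoder_outputs[0]` access in the diff.
        return (torch.randn(*input_ids.shape, 8),)

class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a pretrained decoder."""
    def forward(self, input_ids, encoder_hidden_states=None, **kwargs):
        # Return a tuple whose first element is the logits.
        return (torch.randn(*input_ids.shape, 100),)

class Seq2seqSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ToyEncoder()
        self.decoder = ToyDecoder()

    def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
        # Kwargs prefixed with `decoder_` go to the decoder (prefix
        # stripped); everything else goes to the encoder. Reconstructed
        # from the comment in the diff, not the commit's own code.
        kwargs_decoder = {k[len("decoder_"):]: v
                          for k, v in kwargs.items() if k.startswith("decoder_")}
        kwargs_encoder = {k: v
                          for k, v in kwargs.items() if not k.startswith("decoder_")}

        # Skip the encoder pass when cached hidden states are supplied,
        # e.g. on the second and later steps of autoregressive decoding.
        encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
        if encoder_hidden_states is None:
            encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
            encoder_hidden_states = encoder_outputs[0]
        else:
            encoder_outputs = ()

        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
        decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
        return decoder_outputs + encoder_outputs

model = Seq2seqSketch()
src = torch.randint(0, 100, (2, 7))   # encoder input ids
tgt = torch.randint(0, 100, (2, 5))   # decoder input ids

# Training / first prediction pass: runs both encoder and decoder.
outputs = model(src, tgt)             # (logits, encoder_hidden_states)

# Later prediction passes: reuse the cached encoder states; the encoder
# is skipped, though `src` must still be passed positionally.
outputs = model(src, tgt, encoder_hidden_states=outputs[1])

Note that `decoder_outputs + encoder_outputs` concatenates the two output tuples, which is why callers can recover the encoder hidden states from the first pass and feed them back in on subsequent calls.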