"git@developer.sourcefind.cn:wuxk1/megatron-lm.git" did not exist on "d928dd5159ff95a380f8169e43b06d69abe006ec"
Commit bfbe68f0 authored by thomwolf's avatar thomwolf
Browse files

update forward pass

parent 0ef9bc92
...@@ -218,12 +218,14 @@ class PreTrainedSeq2seq(nn.Module): ...@@ -218,12 +218,14 @@ class PreTrainedSeq2seq(nn.Module):
if encoder_hidden_states is None: if encoder_hidden_states is None:
encoder_outputs = self.encoder(*inputs, *kwargs) encoder_outputs = self.encoder(*inputs, *kwargs)
encoder_hidden_states = encoder_outputs[0] encoder_hidden_states = encoder_outputs[0]
else:
encoder_outputs = (,)
# Decode # Decode
decoder_kwargs['encoder_hidden_states'] = encoder_hidden_states decoder_kwargs['encoder_hidden_states'] = encoder_hidden_states
decoder_outputs = self.decoder(**decoder_kwargs) decoder_outputs = self.decoder(**decoder_kwargs)
return decoder_outputs return decoder_outputs + encoder_outputs
class Model2Model(PreTrainedSeq2seq): class Model2Model(PreTrainedSeq2seq):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment