"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "dd116abfc48e8023bb425c2dd5bd954ee99d7a9c"
Commit f873a3ed authored by Rémi Louf

the decoder attends to the output of the encoder stack (last layer)

parent 56e2ee4e
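The change below wires the decoder's cross-attention to the final hidden states of the encoder stack rather than a per-layer slice. A minimal sketch of that idea (simplified single-head attention, illustrative names only, not the repository's API):

import torch

# Simplified single-head cross-attention: decoder queries attend over the
# encoder's *last-layer* hidden states, and every decoder layer reuses that
# same tensor. All names here are hypothetical.
def cross_attend(decoder_hidden, encoder_last_hidden):
    # decoder_hidden:      (batch, tgt_len, hidden)
    # encoder_last_hidden: (batch, src_len, hidden) -- output of the final encoder layer
    scores = torch.matmul(decoder_hidden, encoder_last_hidden.transpose(-1, -2))
    scores = scores / (decoder_hidden.size(-1) ** 0.5)
    weights = torch.softmax(scores, dim=-1)             # (batch, tgt_len, src_len)
    return torch.matmul(weights, encoder_last_hidden)   # (batch, tgt_len, hidden)

batch, src_len, tgt_len, hidden = 2, 7, 5, 16
encoder_last_hidden = torch.randn(batch, src_len, hidden)
decoder_hidden = torch.randn(batch, tgt_len, hidden)
assert cross_attend(decoder_hidden, encoder_last_hidden).shape == (batch, tgt_len, hidden)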
@@ -288,8 +288,8 @@ class BertAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
 
-    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
-        self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
+        self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
@@ -334,13 +334,13 @@ class BertLayer(nn.Module):
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
 
-    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
         self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
         attention_output = self_attention_outputs[0]
         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
 
-        if self.is_decoder and encoder_hidden_state is not None:
-            cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
+        if self.is_decoder and encoder_hidden_states is not None:
+            cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
             attention_output = cross_attention_outputs[0]
             outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights
 
@@ -364,8 +364,7 @@ class BertEncoder(nn.Module):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            encoder_hidden_state = encoder_hidden_states[i]
-            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_state, encoder_attention_mask)
+            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
             hidden_states = layer_outputs[0]
 
             if self.output_attentions:
...
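The BertEncoder hunk above drops the per-layer slice encoder_hidden_states[i]. A condensed sketch of the resulting loop, with the output_hidden_states/output_attentions bookkeeping omitted, using the same names as the diff:

# Each layer in the stack now receives the *same* encoder_hidden_states tensor
# instead of a per-layer slice encoder_hidden_states[i].
def run_layer_stack(layers, hidden_states, attention_mask, head_mask,
                    encoder_hidden_states=None, encoder_attention_mask=None):
    for i, layer_module in enumerate(layers):
        layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i],
                                     encoder_hidden_states, encoder_attention_mask)
        hidden_states = layer_outputs[0]
    return hidden_states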
@@ -165,7 +165,7 @@ class PreTrainedSeq2seq(nn.Module):
         encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
         if encoder_hidden_states is None:
             encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
-            encoder_hidden_states = encoder_outputs[0]
+            encoder_hidden_states = encoder_outputs[0][-1]  # output of the encoder *stack*
         else:
             encoder_outputs = ()
...
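A hedged illustration of the new indexing in PreTrainedSeq2seq: assuming the first element of the encoder's output tuple holds the hidden states of every layer (embeddings plus one entry per layer), [-1] selects the last layer of the stack, which is the tensor the decoder cross-attends to. The exact structure of encoder_outputs depends on the encoder's configuration, so the shapes below are purely illustrative:

import torch

# Hypothetical return shape: encoder_outputs[0] is a tuple of per-layer hidden states.
num_layers, batch, src_len, hidden = 4, 2, 7, 16
all_hidden_states = tuple(torch.randn(batch, src_len, hidden)
                          for _ in range(num_layers + 1))  # embeddings + each layer
encoder_outputs = (all_hidden_states,)

encoder_hidden_states = encoder_outputs[0][-1]  # output of the last encoder layer
assert encoder_hidden_states.shape == (batch, src_len, hidden)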