"...static/git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "4131d95cef6a08ced14d04df240e3931c25001aa"
Commit 5d29f8e9 authored by VictorSanh's avatar VictorSanh
Browse files

fix bugs

parent a8ad8304
......@@ -274,7 +274,8 @@ class TransformerBlock(nn.Module):
sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask)
if self.output_attentions:
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
else:
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
assert type(sa_output) == tuple
sa_output = sa_output[0]
sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim)
......@@ -329,6 +330,9 @@ class Transformer(nn.Module):
if self.output_attentions:
attentions, hidden_state = hidden_state
all_attentions = all_attentions + (attentions,)
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
assert type(hidden_state) == tuple
hidden_state = hidden_state[0]
all_hidden_states = all_hidden_states + (hidden_state,)
outputs = (hidden_state,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment