"...composable_kernel.git" did not exist on "5b57ab96a8208eec1969a3dcadb555a6246ddb95"
Commit 5d29f8e9 authored by VictorSanh's avatar VictorSanh
Browse files

fix bugs

parent a8ad8304
...@@ -274,7 +274,8 @@ class TransformerBlock(nn.Module): ...@@ -274,7 +274,8 @@ class TransformerBlock(nn.Module):
sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask) sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask)
if self.output_attentions: if self.output_attentions:
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
else: else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
assert type(sa_output) == tuple
sa_output = sa_output[0] sa_output = sa_output[0]
sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim) sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim)
...@@ -329,6 +330,9 @@ class Transformer(nn.Module): ...@@ -329,6 +330,9 @@ class Transformer(nn.Module):
if self.output_attentions: if self.output_attentions:
attentions, hidden_state = hidden_state attentions, hidden_state = hidden_state
all_attentions = all_attentions + (attentions,) all_attentions = all_attentions + (attentions,)
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
assert type(hidden_state) == tuple
hidden_state = hidden_state[0]
all_hidden_states = all_hidden_states + (hidden_state,) all_hidden_states = all_hidden_states + (hidden_state,)
outputs = (hidden_state,) outputs = (hidden_state,)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment