"src/vscode:/vscode.git/clone" did not exist on "f3fac68c55a3f028a364e5611820ccf8a77cf297"
Commit d6be0c7e authored by Myle Ott's avatar Myle Ott
Browse files

Use FP32 for multi-head attention softmax

parent 2d27ae08
@@ -129,7 +129,7 @@ class MultiheadAttention(nn.Module):
         float('-inf'),
     ).type_as(attn_weights)  # FP16 support: cast to float and back
     attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-    attn_weights = F.softmax(attn_weights, dim=-1)
+    attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
     attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
     attn = torch.bmm(attn_weights, v)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment