Unverified Commit 3a27ba3d authored by Younes Belkada, committed by GitHub

Fix opt softmax small nit (#19243)

* fix opt softmax nit

- Use the same logic as 1eb09537550734a783c194e416029cb9bc4cb119 for consistency

* Update src/transformers/models/opt/modeling_opt.py
parent ba9e336f
src/transformers/models/opt/modeling_opt.py
@@ -218,11 +218,10 @@ class OPTAttention(nn.Module):
         attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
         attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-        dtype_attn_weights = attn_weights.dtype
         # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
-        if dtype_attn_weights == torch.float16:
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)
+        if attn_weights.dtype == torch.float16:
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
         else:
             attn_weights = nn.functional.softmax(attn_weights, dim=-1)
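For context, a minimal standalone sketch of the pattern this diff keeps: when the attention scores are fp16, the softmax is computed in fp32 for numerical stability and the result is cast back to fp16. The helper name stable_softmax and the example tensor are illustrative only, not part of the repository.

import torch
import torch.nn as nn

def stable_softmax(attn_weights: torch.Tensor) -> torch.Tensor:
    # Upcast fp16 scores to fp32 before the softmax, then cast back to fp16.
    if attn_weights.dtype == torch.float16:
        return nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
    return nn.functional.softmax(attn_weights, dim=-1)

# Example: rows filled with the fp16 minimum still normalize to ~1 per row.
scores = torch.full((1, 4, 4), torch.finfo(torch.float16).min, dtype=torch.float16)
probs = stable_softmax(scores)
print(probs.dtype, probs.sum(dim=-1))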