"vscode:/vscode.git/clone" did not exist on "db78fac5dfcb4643c5f43579879995b6e6a3dfc6"
Commit 43722c5e authored by Naman Goyal, committed by Facebook Github Bot

convert logits to fp32 for calculating loss in masked_lm_loss criterion

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/568

Differential Revision: D15308483

Pulled By: myleott

fbshipit-source-id: 9d898ce523e46e6b6fb444274f478da0b577b603
parent 5dcc855a
@@ -7,6 +7,7 @@
 import math
+import torch
 import torch.nn.functional as F
 from fairseq import utils
@@ -22,8 +23,8 @@ def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
     assert logits.size(0) == targets.size(-1), \
         "Logits and Targets tensor shapes don't match up"
-    loss = F.cross_entropy(
-        logits,
+    loss = F.nll_loss(
+        F.log_softmax(logits, -1, dtype=torch.float32),
         targets,
         reduction="sum",
         ignore_index=ignore_index,
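For context, here is a minimal standalone sketch of the new pattern (the shapes and tensor values are illustrative, not from the commit): under mixed-precision training the model emits fp16 logits, and passing dtype=torch.float32 to F.log_softmax upcasts them before the softmax, so the normalization and the summed loss are computed in full precision. F.nll_loss applied to log-probabilities is then equivalent to F.cross_entropy applied to raw logits.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: 8 token positions over a 1000-word vocabulary.
logits = torch.randn(8, 1000).half()    # fp16 logits, as in mixed-precision training
targets = torch.randint(0, 1000, (8,))  # gold token ids

# dtype=torch.float32 casts the fp16 logits to fp32 *before* the
# log-softmax, so the normalization and summed reduction stay in full
# precision; nll_loss on log-probs matches cross_entropy on logits.
loss = F.nll_loss(
    F.log_softmax(logits, -1, dtype=torch.float32),
    targets,
    reduction="sum",
    ignore_index=-100,
)
print(loss.dtype)  # torch.float32
```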