Unverified Commit 9bed3554 authored by Suraj Patil, committed by GitHub

[s2s] fix label_smoothed_nll_loss (#6344)

parent 99f73bcc
@@ -29,17 +29,15 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
         pad_mask = target.eq(ignore_index)
         nll_loss.masked_fill_(pad_mask, 0.0)
         smooth_loss.masked_fill_(pad_mask, 0.0)
-        bs = pad_mask.long().sum()
     else:
         nll_loss = nll_loss.squeeze(-1)
         smooth_loss = smooth_loss.squeeze(-1)
-        bs = lprobs.shape[0]
 
     nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
     smooth_loss = smooth_loss.sum()
     eps_i = epsilon / lprobs.size(-1)
     loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
-    return loss / bs, nll_loss / bs
+    return loss, nll_loss
 
 
 def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
...
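
For context, here is a minimal, self-contained sketch of how the patched helper can be called. The removed normalizer bs = pad_mask.long().sum() counted padding positions rather than examples, so the old loss / bs divided by an unrelated quantity; after this change the function returns plain sums and leaves any normalization to the caller. The lines of the function above the diff hunk are reconstructed from the usual fairseq-style implementation and, like the toy shapes, pad id, and smoothing value below, are illustrative assumptions rather than part of this commit.

import torch
import torch.nn.functional as F


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    # Post-patch behavior from the diff above; the gather/sum setup before the
    # hunk is reconstructed (fairseq-style) and is an assumption, not the diff.
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)   # NLL of the gold token
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)   # uniform-smoothing term
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)          # zero out padded positions
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)

    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss                             # sums; caller normalizes


# Toy usage (illustrative values only).
batch_size, seq_len, vocab_size = 2, 5, 11
pad_token_id = 0                                      # assumed pad id; must be a valid vocab index
epsilon = 0.1                                         # label-smoothing factor

logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(1, vocab_size, (batch_size, seq_len))
labels[1, 3:] = pad_token_id                          # pretend the second sequence is padded

lprobs = F.log_softmax(logits, dim=-1)
loss, nll_loss = label_smoothed_nll_loss(lprobs, labels, epsilon, ignore_index=pad_token_id)
print(loss.item(), nll_loss.item())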