Commit 51eb9802 authored by Jiatao Gu, committed by Facebook Github Bot

clean up the NAT loss (#921)

Summary:
Clean up the original NAT loss and make it general enough to accommodate new losses used in NAT models.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/921

Differential Revision: D18610145

Pulled By: MultiPath

fbshipit-source-id: d04dd0fc4047b5f8e332cfe66b1e28cbf39494af
parent 831b6b6e
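
For context, a minimal sketch (not part of the commit) of the generalized interface this change introduces: each NAT model's forward() now returns one nested dict per training objective (keys "out", "tgt", "mask", "ls", "factor", "nll_loss", as in the diffs below), and the criterion iterates over them generically instead of hard-coding each loss. compute_loss() here is only a simplified stand-in for LabelSmoothedDualImitationCriterion._compute_loss; tensor shapes and values are placeholders.

import torch
import torch.nn.functional as F


def compute_loss(out, tgt, mask=None, ls=0.0, factor=1.0):
    # Label-smoothed NLL over the (optionally masked) positions, scaled by factor.
    if mask is not None:
        out, tgt = out[mask], tgt[mask]
    lprobs = F.log_softmax(out, dim=-1)
    nll = F.nll_loss(lprobs, tgt, reduction="mean")
    loss = (1.0 - ls) * nll + ls * (-lprobs.mean())
    return {"loss": loss * factor, "nll_loss": nll, "factor": factor}


# Example of what a CMLM-style forward() now returns (placeholder tensors):
word_mask = torch.zeros(2, 5, dtype=torch.bool)
word_mask[:, :3] = True
outputs = {
    "word_ins": {
        "out": torch.randn(2, 5, 100),          # token logits
        "tgt": torch.randint(0, 100, (2, 5)),   # token targets
        "mask": word_mask,                      # which positions are supervised
        "ls": 0.1,                              # label smoothing for this objective
        "nll_loss": True,                       # contributes to the logged NLL
    },
    "length": {
        "out": torch.randn(2, 20),              # length logits
        "tgt": torch.randint(0, 20, (2,)),      # length targets
        "factor": 0.1,                          # e.g. decoder.length_loss_factor
    },
}

# The refactored criterion no longer hard-codes each objective; it just loops:
losses = [
    compute_loss(o["out"], o["tgt"], o.get("mask"),
                 o.get("ls", 0.0), o.get("factor", 1.0))
    for o in outputs.values()
]
total = sum(l["loss"] for l in losses)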

@@ -70,8 +70,8 @@ class LabelSmoothedDualImitationCriterion(FairseqCriterion):
         loss = loss * factor
         return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor}
 
-    def _custom_loss(self, loss, name="loss"):
-        return {"name": name, "loss": loss, "factor": 1}
+    def _custom_loss(self, loss, name="loss", factor=1.0):
+        return {"name": name, "loss": loss, "factor": factor}
 
     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.

@@ -90,59 +90,34 @@ class LabelSmoothedDualImitationCriterion(FairseqCriterion):
         tgt_tokens, prev_output_tokens = sample["target"], sample["prev_target"]
         outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens)
+        losses, nll_loss = [], []
 
-        losses = []
-        if "mask_ins_out" in outputs:
-            mask_ins_losses = self._compute_loss(
-                outputs["mask_ins_out"],
-                outputs["mask_ins_tgt"],
-                outputs["mask_ins_mask"],
-                name="m_ins-loss",
-                factor=1 if "mask_ins_w" not in outputs else outputs["mask_ins_w"],
-            )
-            losses += [mask_ins_losses]
-
-        if "word_ins_out" in outputs:
-            word_ins_losses = self._compute_loss(
-                outputs["word_ins_out"],
-                outputs["word_ins_tgt"],
-                outputs["word_ins_mask"],
-                self.args.label_smoothing,
-                name="w_ins-loss",
-                factor=1 if "word_ins_w" not in outputs else outputs["word_ins_w"],
-            )
-            losses += [word_ins_losses]
-            nll_loss = word_ins_losses["nll_loss"]
-
-        if "word_del_out" in outputs:
-            word_del_losses = self._compute_loss(
-                outputs["word_del_out"],
-                outputs["word_del_tgt"],
-                outputs["word_del_mask"],
-                0.01,
-                name="w_del-loss",
-                factor=1 if "word_del_w" not in outputs else outputs["word_del_w"],
-            )
-            losses += [word_del_losses]
-
-        if "length_out" in outputs:
-            length_losses = self._compute_loss(
-                outputs["length_out"],
-                outputs["length_tgt"],
-                name="len-loss",
-                factor=1 if "length_w" not in outputs else outputs["length_w"],
-            )
-            losses += [length_losses]
-
-        for w in outputs:
-            if "-loss" in w:
-                losses += [self._custom_loss(outputs[w], w)]
+        for obj in outputs:
+            if outputs[obj].get("loss", None) is None:
+                _losses = self._compute_loss(
+                    outputs[obj].get("out"),
+                    outputs[obj].get("tgt"),
+                    outputs[obj].get("mask", None),
+                    outputs[obj].get("ls", 0.0),
+                    name=obj + '-loss',
+                    factor=outputs[obj].get("factor", 1.0)
+                )
+            else:
+                _losses = self._custom_loss(
+                    outputs[obj].get("loss"),
+                    name=obj + '-loss',
+                    factor=outputs[obj].get("factor", 1.0)
+                )
+
+            losses += [_losses]
+            if outputs[obj].get("nll_loss", False):
+                nll_loss += [_losses.get("nll_loss", 0.0)]
 
         loss = sum(l["loss"] for l in losses)
+        nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 \
+            else loss.new_tensor(0)
 
-        # NOTE: as we are summing up per token mlm loss and per sentence nsp loss
+        # NOTE:
         # we don't need to use sample_size as denominator for the gradient
         # here sample_size is just used for logging
         sample_size = 1
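
Note, visible in the new dispatch above: if an objective's dict already carries a precomputed "loss" tensor, forward() routes it through _custom_loss instead of _compute_loss, so a model can ship an arbitrary extra loss without touching the criterion. A hypothetical model-side entry (the key name is illustrative, not from this commit) could look like:

import torch

outputs = {
    # Because this entry has a precomputed "loss", the criterion calls
    # self._custom_loss() on it and only applies the optional "factor";
    # no out/tgt/mask/ls fields are needed.  Key name is made up for illustration.
    "my_regularizer": {
        "loss": torch.tensor(0.3),   # any scalar loss computed inside the model
        "factor": 0.5,               # optional weight, defaults to 1.0
    },
}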

@@ -44,13 +44,17 @@ class CMLMNATransformerModel(NATransformerModel):
             prev_output_tokens, encoder_out=encoder_out, tgt_tokens=tgt_tokens
         )
         word_ins_mask = prev_output_tokens.eq(self.unk)
         return {
-            "word_ins_out": word_ins_out,
-            "word_ins_tgt": word_ins_tgt,
-            "word_ins_mask": word_ins_mask,
-            "length_out": length_out,
-            "length_tgt": length_tgt,
-            "length_w": self.decoder.length_loss_factor,
+            "word_ins": {
+                "out": word_ins_out, "tgt": word_ins_tgt,
+                "mask": word_ins_mask, "ls": self.args.label_smoothing,
+                "nll_loss": True
+            },
+            "length": {
+                "out": length_out, "tgt": length_tgt,
+                "factor": self.decoder.length_loss_factor
+            }
         }
 
     def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):

@@ -161,9 +161,11 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
         word_ins_masks = prev_output_tokens[:, 1:].ne(self.pad)
 
         return {
-            "word_ins_out": word_ins_out,
-            "word_ins_tgt": word_ins_tgt,
-            "word_ins_mask": word_ins_masks,
+            "word_ins": {
+                "out": word_ins_out, "tgt": word_ins_tgt,
+                "mask": word_ins_masks, "ls": self.args.label_smoothing,
+                "nll_loss": True
+            }
         }
 
     def forward_decoder(

@@ -127,12 +127,15 @@ class IterNATransformerModel(NATransformerModel):
         word_ins_mask = torch.cat(word_ins_masks, 0)
 
         return {
-            "word_ins_out": word_ins_out,
-            "word_ins_tgt": word_ins_tgt,
-            "word_ins_mask": word_ins_mask,
-            "length_out": length_out,
-            "length_tgt": length_tgt,
-            "length_w": self.decoder.length_loss_factor,
+            "word_ins": {
+                "out": word_ins_out, "tgt": word_ins_tgt,
+                "mask": word_ins_mask, "ls": self.args.label_smoothing,
+                "nll_loss": True
+            },
+            "length": {
+                "out": length_out, "tgt": length_tgt,
+                "factor": self.decoder.length_loss_factor
+            }
         }

@@ -389,15 +389,19 @@ class LevenshteinTransformerModel(TransformerModel):
         word_del_masks = word_predictions.ne(self.pad)
 
         return {
-            "mask_ins_out": mask_ins_out,
-            "mask_ins_tgt": mask_ins_targets,
-            "mask_ins_mask": mask_ins_masks,
-            "word_ins_out": word_ins_out,
-            "word_ins_tgt": tgt_tokens,
-            "word_ins_mask": masked_tgt_masks,
-            "word_del_out": word_del_out,
-            "word_del_tgt": word_del_targets,
-            "word_del_mask": word_del_masks,
+            "mask_ins": {
+                "out": mask_ins_out, "tgt": mask_ins_targets,
+                "mask": mask_ins_masks, "ls": 0.01,
+            },
+            "word_ins": {
+                "out": word_ins_out, "tgt": tgt_tokens,
+                "mask": masked_tgt_masks, "ls": self.args.label_smoothing,
+                "nll_loss": True
+            },
+            "word_del": {
+                "out": word_del_out, "tgt": word_del_targets,
+                "mask": word_del_masks
+            }
         }
 
     def forward_encoder(self, encoder_inputs):

@@ -102,14 +102,17 @@ class NATransformerModel(TransformerModel):
         )
 
         return {
-            "word_ins_out": word_ins_out,
-            "word_ins_tgt": word_ins_tgt,
-            "word_ins_mask": word_ins_mask,
-            "length_out": length_out,
-            "length_tgt": length_tgt,
-            "length_w": self.decoder.length_loss_factor,
+            "word_ins": {
+                "out": word_ins_out, "tgt": word_ins_tgt,
+                "mask": word_ins_mask, "ls": self.args.label_smoothing,
+                "nll_loss": True
+            },
+            "length": {
+                "out": length_out, "tgt": length_tgt,
+                "factor": self.decoder.length_loss_factor
+            }
         }
 
     def forward_encoder(self, encoder_inputs):
         return self.encoder(*encoder_inputs)