Commit 7df61696 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add fairseq0.10.2

parents
Pipeline #471 failed with stages
in 0 seconds
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.constants import DDP_BACKEND_CHOICES
from omegaconf import II
@dataclass
class AdaptiveLossConfig(FairseqDataclass):
sentence_avg: bool = II("params.optimization.sentence_avg")
ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend")
@register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig)
class AdaptiveLoss(FairseqCriterion):
"""This is an implementation of the loss function accompanying the adaptive softmax approximation for
graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
(http://arxiv.org/abs/1609.04309)."""
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg
@classmethod
def build_criterion(cls, args, task):
if getattr(args, "ddp_backend", None) == "c10d":
raise Exception(
"AdaptiveLoss is not compatible with the c10d "
"version of DistributedDataParallel. Please use "
"`--ddp-backend=no_c10d` instead."
)
return cls(task, args.sentence_avg)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model.decoder, "adaptive_softmax")
and model.decoder.adaptive_softmax is not None
)
adaptive_softmax = model.decoder.adaptive_softmax
net_output = model(**sample["net_input"])
orig_target = model.get_targets(sample, net_output)
nsentences = orig_target.size(0)
orig_target = orig_target.view(-1)
bsz = orig_target.size(0)
logits, target = adaptive_softmax(net_output[0], orig_target)
assert len(target) == len(logits)
loss = net_output[0].new(1 if reduce else bsz).zero_()
for i in range(len(target)):
if target[i] is not None:
assert target[i].min() >= 0 and target[i].max() <= logits[i].size(1)
loss += F.cross_entropy(
logits[i],
target[i],
ignore_index=self.padding_idx,
reduction="sum" if reduce else "none",
)
orig = utils.strip_pad(orig_target, self.padding_idx)
ntokens = orig.numel()
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
logging_output = {
"loss": loss.data,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq import utils
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from torch import nn
@register_criterion("composite_loss")
class CompositeLoss(LegacyFairseqCriterion):
"""This is a composite loss that, given a list of model outputs and a list of targets,
computes an average of losses for each output-target pair"""
def __init__(self, args, task):
super().__init__(args, task)
self.underlying_criterion = args.underlying_criterion
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
# fmt: off
parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True,
help='underlying criterion to use for the composite loss')
# fmt: on
@staticmethod
def build_underlying_criterion(args, task):
saved_criterion = args.criterion
args.criterion = args.underlying_criterion
assert saved_criterion != args.underlying_criterion
underlying_criterion = task.build_criterion(args)
args.criterion = saved_criterion
return underlying_criterion
@classmethod
def build_criterion(cls, args, task):
underlying_criterion = CompositeLoss.build_underlying_criterion(args, task)
class FakeModel(nn.Module):
def __init__(self, model, net_out, target):
super().__init__()
self.model = model
self.net_out = net_out
self.target = target
def forward(self, **unused):
return self.net_out
def get_normalized_probs(self, net_output, log_probs, sample=None):
return self.model.get_normalized_probs(
net_output, log_probs, sample=sample
)
def get_targets(self, *unused):
return self.target
@property
def decoder(self):
return self.model.decoder
class _CompositeLoss(LegacyFairseqCriterion):
def __init__(self, args, task, underlying_criterion):
super().__init__(args, task)
self.underlying_criterion = underlying_criterion
def forward(self, model, sample, reduce=True):
net_outputs = model(**sample["net_input"])
targets = sample["target"]
bsz = targets[0].size(0)
loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_()
sample_size = 0
logging_output = {}
for o, t in zip(net_outputs[0], targets):
m = FakeModel(model, (o, net_outputs[1]), t)
sample["target"] = t
l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
loss += l
sample_size += ss
loss.div_(len(targets))
sample_size /= len(targets)
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
return underlying_criterion.__class__.aggregate_logging_outputs(
logging_outputs
)
@staticmethod
def reduce_metrics(logging_outputs) -> None:
underlying_criterion.__class__.reduce_metrics(logging_outputs)
return _CompositeLoss(args, task, underlying_criterion)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
@dataclass
class CrossEntropyCriterionConfig(FairseqDataclass):
sentence_avg: bool = II("params.optimization.sentence_avg")
@register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig)
class CrossEntropyCriterion(FairseqCriterion):
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
return loss, sample_size, logging_output
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1)
loss = F.nll_loss(
lprobs,
target,
ignore_index=self.padding_idx,
reduction="sum" if reduce else "none",
)
return loss, loss
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
# we divide by log(2) to convert the loss from base e to base 2
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
from argparse import Namespace
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import safe_round
@register_criterion("ctc")
class CtcCriterion(FairseqCriterion):
def __init__(self, task, wer_args, zero_infinity, sentence_avg, remove_bpe):
super().__init__(task)
self.blank_idx = task.target_dictionary.bos()
self.pad_idx = task.target_dictionary.pad()
self.eos_idx = task.target_dictionary.eos()
self.post_process = remove_bpe if remove_bpe else "letter"
if wer_args is not None:
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(wer_args)
dec_args = Namespace()
dec_args.nbest = 1
dec_args.criterion = "ctc"
dec_args.kenlm_model = wer_compute_kenlm
dec_args.lexicon = wer_lexicon
dec_args.beam = 50
dec_args.beam_size_token = min(50, len(task.target_dictionary))
dec_args.beam_threshold = min(50, len(task.target_dictionary))
dec_args.lm_weight = lm_w
dec_args.word_score = ws_w
dec_args.unk_weight = -math.inf
dec_args.sil_weight = 0
self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
else:
self.w2l_decoder = None
self.zero_infinity = zero_infinity
self.sentence_avg = sentence_avg
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
parser.add_argument(
"--zero-infinity", action="store_true", help="zero inf loss"
)
try:
parser.add_argument(
"--remove-bpe",
"--post-process",
default="letter",
help="remove BPE tokens before scoring (can be set to sentencepiece, letter, and more)",
)
except:
pass # this option might have been added from eval args
parser.add_argument(
"--wer-args",
type=str,
default=None,
help="options for wer computation on valid set using 4 gram lm. this should be a tuple of 4 elements: path to 4-gram lm, \
path to lexicon, lm score, word score",
)
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
lprobs = model.get_normalized_probs(
net_output, log_probs=True
).contiguous() # (T, B, C) from the encoder
if "src_lengths" in sample["net_input"]:
input_lengths = sample["net_input"]["src_lengths"]
else:
non_padding_mask = ~net_output["padding_mask"]
input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (sample["target"] != self.pad_idx) & (
sample["target"] != self.eos_idx
)
targets_flat = sample["target"].masked_select(pad_mask)
target_lengths = sample["target_lengths"]
with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss(
lprobs,
targets_flat,
input_lengths,
target_lengths,
blank=self.blank_idx,
reduction="sum",
zero_infinity=self.zero_infinity,
)
ntokens = (
sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
)
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
logging_output = {
"loss": utils.item(loss.data), # * sample['ntokens'],
"ntokens": ntokens,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
}
if not model.training:
import editdistance
with torch.no_grad():
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
c_err = 0
c_len = 0
w_errs = 0
w_len = 0
wv_errs = 0
for lp, t, inp_l in zip(
lprobs_t,
sample["target_label"]
if "target_label" in sample
else sample["target"],
input_lengths,
):
lp = lp[:inp_l].unsqueeze(0)
decoded = None
if self.w2l_decoder is not None:
decoded = self.w2l_decoder.decode(lp)
if len(decoded) < 1:
decoded = None
else:
decoded = decoded[0]
if len(decoded) < 1:
decoded = None
else:
decoded = decoded[0]
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
)
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
targ_units_arr = targ.tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
targ_words = post_process(targ_units, self.post_process).split()
pred_units = self.task.target_dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
if decoded is not None and "words" in decoded:
pred_words = decoded["words"]
w_errs += editdistance.eval(pred_words, targ_words)
wv_errs += editdistance.eval(pred_words_raw, targ_words)
else:
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
wv_errs += dist
w_len += len(targ_words)
logging_output["wv_errors"] = wv_errs
logging_output["w_errors"] = w_errs
logging_output["w_total"] = w_len
logging_output["c_errors"] = c_err
logging_output["c_total"] = c_len
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
metrics.log_scalar("_c_errors", c_errors)
c_total = sum(log.get("c_total", 0) for log in logging_outputs)
metrics.log_scalar("_c_total", c_total)
w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
metrics.log_scalar("_w_errors", w_errors)
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
metrics.log_scalar("_wv_errors", wv_errors)
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
metrics.log_scalar("_w_total", w_total)
if c_total > 0:
metrics.log_derived(
"uer",
lambda meters: safe_round(
meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
)
if meters["_c_total"].sum > 0
else float("nan"),
)
if w_total > 0:
metrics.log_derived(
"wer",
lambda meters: safe_round(
meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
metrics.log_derived(
"raw_wer",
lambda meters: safe_round(
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from typing import Any, Dict, List
from fairseq import metrics, utils
from fairseq.dataclass.utils import gen_parser_from_dataclass
from torch.nn.modules.loss import _Loss
class FairseqCriterion(_Loss):
def __init__(self, task):
super().__init__()
self.task = task
if hasattr(task, "target_dictionary"):
tgt_dict = task.target_dictionary
self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100
@classmethod
def add_args(cls, parser):
"""Add criterion-specific arguments to the parser."""
dc = getattr(cls, "__dataclass", None)
if dc is not None:
gen_parser_from_dataclass(parser, dc())
@classmethod
def build_criterion(cls, args, task):
"""Construct a criterion from command-line args."""
# Criterions can override this, but for convenience we also try
# to automatically map argparse.Namespace keys to corresponding
# arguments in the __init__.
init_args = {}
for p in inspect.signature(cls).parameters.values():
if (
p.kind == p.POSITIONAL_ONLY
or p.kind == p.VAR_POSITIONAL
or p.kind == p.VAR_KEYWORD
):
# we haven't implemented inference for these argument types,
# but PRs welcome :)
raise NotImplementedError("{} not supported".format(p.kind))
assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}
if p.name == "task":
init_args["task"] = task
elif hasattr(args, p.name):
init_args[p.name] = getattr(args, p.name)
elif p.default != p.empty:
pass # we'll use the default value
else:
raise NotImplementedError(
"Unable to infer Criterion arguments, please implement "
"{}.build_criterion".format(cls.__name__)
)
return cls(**init_args)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
raise NotImplementedError
@staticmethod
def aggregate_logging_outputs(
logging_outputs: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(
"The aggregate_logging_outputs API is deprecated. "
"Please use the reduce_metrics API instead."
)
raise NotImplementedError
@classmethod
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
"""Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(
"Criterions should implement the reduce_metrics API. "
"Falling back to deprecated aggregate_logging_outputs API."
)
agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
for k, v in agg_logging_outputs.items():
if k in {"nsentences", "ntokens", "sample_size"}:
continue
metrics.log_scalar(k, v)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return False
class LegacyFairseqCriterion(FairseqCriterion):
def __init__(self, args, task):
super().__init__(task=task)
self.args = args
utils.deprecation_warning(
"Criterions should take explicit arguments instead of an "
"argparse.Namespace object, please update your criterion by "
"extending FairseqCriterion instead of LegacyFairseqCriterion."
)
@classmethod
def build_criterion(cls, args, task):
"""Construct a criterion from command-line args."""
return cls(args, task)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
if reduce:
nll_loss = nll_loss.sum()
smooth_loss = smooth_loss.sum()
eps_i = epsilon / lprobs.size(-1)
loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
return loss, nll_loss
@register_criterion("label_smoothed_cross_entropy")
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
):
super().__init__(task)
self.sentence_avg = sentence_avg
self.eps = label_smoothing
self.ignore_prefix_size = ignore_prefix_size
self.report_accuracy = report_accuracy
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
# fmt: off
parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
help='epsilon for label smoothing, 0 means no label smoothing')
parser.add_argument('--report-accuracy', action='store_true',
help='report accuracy metric')
parser.add_argument('--ignore-prefix-size', default=0, type=int,
help='Ignore first N tokens')
# fmt: on
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
return loss, sample_size, logging_output
def get_lprobs_and_target(self, model, net_output, sample):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = model.get_targets(sample, net_output)
if self.ignore_prefix_size > 0:
if getattr(lprobs, "batch_first", False):
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
else:
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
target = target[self.ignore_prefix_size :, :].contiguous()
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)
return loss, nll_loss
def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
)
total = torch.sum(mask)
return n_correct, total
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
if total > 0:
metrics.log_scalar("total", total)
n_correct = utils.item(
sum(log.get("n_correct", 0) for log in logging_outputs)
)
metrics.log_scalar("n_correct", n_correct)
metrics.log_derived(
"accuracy",
lambda meters: round(
meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
)
if meters["total"].sum > 0
else float("nan"),
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
@register_criterion("label_smoothed_cross_entropy_with_alignment")
class LabelSmoothedCrossEntropyCriterionWithAlignment(
LabelSmoothedCrossEntropyCriterion
):
def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
super().__init__(task, sentence_avg, label_smoothing)
self.alignment_lambda = alignment_lambda
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
LabelSmoothedCrossEntropyCriterion.add_args(parser)
parser.add_argument(
"--alignment-lambda",
default=0.05,
type=float,
metavar="D",
help="weight for the alignment loss",
)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
alignment_loss = None
# Compute alignment loss only for training set and non dummy batches.
if "alignments" in sample and sample["alignments"] is not None:
alignment_loss = self.compute_alignment_loss(sample, net_output)
if alignment_loss is not None:
logging_output["alignment_loss"] = utils.item(alignment_loss.data)
loss += self.alignment_lambda * alignment_loss
return loss, sample_size, logging_output
def compute_alignment_loss(self, sample, net_output):
attn_prob = net_output[1]["attn"][0]
bsz, tgt_sz, src_sz = attn_prob.shape
attn = attn_prob.view(bsz * tgt_sz, src_sz)
align = sample["alignments"]
align_weights = sample["align_weights"].float()
if len(align) > 0:
# Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to
# the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing.
loss = -(
(attn[align[:, 1][:, None], align[:, 0][:, None]]).log()
* align_weights[:, None]
).sum()
else:
return None
return loss
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
nll_loss_sum = utils.item(
sum(log.get("nll_loss", 0) for log in logging_outputs)
)
alignment_loss_sum = utils.item(
sum(log.get("alignment_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_scalar(
"alignment_loss",
alignment_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
"""
Function to compute the cross entropy loss. The default value of
ignore_index is the same as the default value for F.cross_entropy in
pytorch.
"""
assert logits.size(0) == targets.size(
-1
), "Logits and Targets tensor shapes don't match up"
loss = F.nll_loss(
F.log_softmax(logits, -1, dtype=torch.float32),
targets,
reduction="sum",
ignore_index=ignore_index,
)
return loss
@register_criterion("legacy_masked_lm_loss")
class LegacyMaskedLmLoss(FairseqCriterion):
"""
Implementation for the loss used in masked language model (MLM) training.
This optionally also computes the next sentence prediction (NSP) loss and
adds it to the overall loss based on the specified args. There are three
cases to consider:
1) Generic MLM training without NSP loss. In this case sentence_targets
and sentence_logits are both None.
2) BERT training without NSP loss. In this case sentence_targets is
not None but sentence_logits is None and we should not be computing
a sentence level loss.
3) BERT training with NSP loss. In this case both sentence_targets and
sentence_logits are not None and we should be computing a sentence
level loss. The weight of the sentence level loss is specified as
an argument.
"""
def __init__(self, task, masked_lm_only, nsp_loss_weight):
super().__init__(task)
self.masked_lm_only = masked_lm_only
self.nsp_loss_weight = nsp_loss_weight
@staticmethod
def add_args(parser):
"""Args for MaskedLM Loss"""
# Default for masked_lm_only is False so as to not break BERT training
parser.add_argument(
"--masked-lm-only",
default=False,
action="store_true",
help="compute MLM loss only",
)
parser.add_argument(
"--nsp-loss-weight",
default=1.0,
type=float,
help="weight for next sentence prediction" " loss (default 1)",
)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
lm_logits, output_metadata = model(**sample["net_input"])
# reshape lm_logits from (N,T,C) to (N*T,C)
lm_logits = lm_logits.view(-1, lm_logits.size(-1))
lm_targets = sample["lm_target"].view(-1)
lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets, self.padding_idx)
# compute the number of tokens for which loss is computed. This is used
# to normalize the loss
ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
loss = lm_loss / ntokens
nsentences = sample["nsentences"]
# nsentences = 0
# Compute sentence loss if masked_lm_only is False
sentence_loss = None
if not self.masked_lm_only:
sentence_logits = output_metadata["sentence_logits"]
sentence_targets = sample["sentence_target"].view(-1)
# This needs to be recomputed due to some differences between
# TokenBlock and BlockPair dataset. This can be resolved with a
# refactor of BERTModel which we will do in the future.
# TODO: Remove this after refactor of BERTModel
nsentences = sentence_targets.size(0)
# Check for logits being none which can happen when remove_heads
# is set to true in the BERT model. Ideally we should set
# masked_lm_only to true in this case, but that requires some
# refactor in the BERT model.
if sentence_logits is not None:
sentence_loss = compute_cross_entropy_loss(
sentence_logits, sentence_targets
)
loss += self.nsp_loss_weight * (sentence_loss / nsentences)
# NOTE: as we are summing up per token mlm loss and per sentence nsp loss
# we don't need to use sample_size as denominator for the gradient
# here sample_size is just used for logging
sample_size = 1
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"lm_loss": utils.item(lm_loss.data) if reduce else lm_loss.data,
# sentence loss is not always computed
"sentence_loss": (
(utils.item(sentence_loss.data) if reduce else sentence_loss.data)
if sentence_loss is not None
else 0.0
),
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
lm_loss_sum = sum(log.get("lm_loss", 0) for log in logging_outputs)
sentence_loss_sum = sum(log.get("sentence_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_loss = sum(log.get("loss", 0) for log in logging_outputs)
metrics.log_scalar(
"loss",
agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.0,
sample_size,
round=3,
)
metrics.log_scalar(
"lm_loss",
lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
ntokens,
round=3,
)
metrics.log_scalar(
"sentence_loss",
sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.0,
nsentences,
round=3,
)
metrics.log_scalar(
"nll_loss",
lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
ntokens,
round=3,
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, modules, utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("masked_lm")
class MaskedLmLoss(FairseqCriterion):
"""
Implementation for the loss used in masked language model (MLM) training.
"""
def __init__(self, task, tpu=False):
super().__init__(task)
self.tpu = tpu
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
masked_tokens = sample["target"].ne(self.padding_idx)
sample_size = masked_tokens.int().sum()
# Rare: when all tokens are masked, project all tokens.
# We use torch.where to avoid device-to-host transfers,
# except on CPU where torch.where is not well supported
# (see github.com/pytorch/pytorch/issues/26247).
if self.tpu:
masked_tokens = None # always project all tokens on TPU
elif masked_tokens.device == torch.device("cpu"):
if not masked_tokens.any():
masked_tokens = None
else:
masked_tokens = torch.where(
masked_tokens.any(),
masked_tokens,
masked_tokens.new([True]),
)
logits = model(**sample["net_input"], masked_tokens=masked_tokens)[0]
targets = model.get_targets(sample, [logits])
if masked_tokens is not None:
targets = targets[masked_tokens]
loss = modules.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
reduction="sum",
ignore_index=self.padding_idx,
)
logging_output = {
"loss": loss if self.tpu else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from torch import Tensor
@register_criterion("nat_loss")
class LabelSmoothedDualImitationCriterion(FairseqCriterion):
def __init__(self, task, label_smoothing):
super().__init__(task)
self.label_smoothing = label_smoothing
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
parser.add_argument(
"--label-smoothing",
default=0.0,
type=float,
metavar="D",
help="epsilon for label smoothing, 0 means no label smoothing",
)
def _compute_loss(
self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0
):
"""
outputs: batch x len x d_model
targets: batch x len
masks: batch x len
policy_logprob: if there is some policy
depends on the likelihood score as rewards.
"""
def mean_ds(x: Tensor, dim=None) -> Tensor:
return (
x.float().mean().type_as(x)
if dim is None
else x.float().mean(dim).type_as(x)
)
if masks is not None:
outputs, targets = outputs[masks], targets[masks]
if masks is not None and not masks.any():
nll_loss = torch.tensor(0)
loss = nll_loss
else:
logits = F.log_softmax(outputs, dim=-1)
if targets.dim() == 1:
losses = F.nll_loss(logits, targets.to(logits.device), reduction="none")
else: # soft-labels
losses = F.kl_div(logits, targets.to(logits.device), reduction="none")
losses = losses.sum(-1)
nll_loss = mean_ds(losses)
if label_smoothing > 0:
loss = (
nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing
)
else:
loss = nll_loss
loss = loss * factor
return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor}
def _custom_loss(self, loss, name="loss", factor=1.0):
return {"name": name, "loss": loss, "factor": factor}
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
nsentences, ntokens = sample["nsentences"], sample["ntokens"]
# B x T
src_tokens, src_lengths = (
sample["net_input"]["src_tokens"],
sample["net_input"]["src_lengths"],
)
tgt_tokens, prev_output_tokens = sample["target"], sample["prev_target"]
outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens)
losses, nll_loss = [], []
for obj in outputs:
if outputs[obj].get("loss", None) is None:
_losses = self._compute_loss(
outputs[obj].get("out"),
outputs[obj].get("tgt"),
outputs[obj].get("mask", None),
outputs[obj].get("ls", 0.0),
name=obj + "-loss",
factor=outputs[obj].get("factor", 1.0),
)
else:
_losses = self._custom_loss(
outputs[obj].get("loss"),
name=obj + "-loss",
factor=outputs[obj].get("factor", 1.0),
)
losses += [_losses]
if outputs[obj].get("nll_loss", False):
nll_loss += [_losses.get("nll_loss", 0.0)]
loss = sum(l["loss"] for l in losses)
nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 else loss.new_tensor(0)
# NOTE:
# we don't need to use sample_size as denominator for the gradient
# here sample_size is just used for logging
sample_size = 1
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
for l in losses:
logging_output[l["name"]] = (
utils.item(l["loss"].data / l["factor"])
if reduce
else l[["loss"]].data / l["factor"]
)
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
nll_loss = utils.item(sum(log.get("nll_loss", 0) for log in logging_outputs))
metrics.log_scalar(
"loss", loss / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss / sample_size / math.log(2), sample_size, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
for key in logging_outputs[0]:
if key[-5:] == "-loss":
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(
key[:-5],
val / sample_size / math.log(2) if sample_size > 0 else 0.0,
sample_size,
round=3,
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("sentence_prediction")
class SentencePredictionCriterion(FairseqCriterion):
def __init__(self, task, classification_head_name, regression_target):
super().__init__(task)
self.classification_head_name = classification_head_name
self.regression_target = regression_target
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--classification-head-name',
default='sentence_classification_head',
help='name of the classification head to use')
# fmt: on
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model, "classification_heads")
and self.classification_head_name in model.classification_heads
), "model must provide sentence classification head for --criterion=sentence_prediction"
logits, _ = model(
**sample["net_input"],
features_only=True,
classification_head_name=self.classification_head_name,
)
targets = model.get_targets(sample, [logits]).view(-1)
sample_size = targets.numel()
if not self.regression_target:
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = F.nll_loss(lprobs, targets, reduction="sum")
else:
logits = logits.view(-1).float()
targets = targets.float()
loss = F.mse_loss(logits, targets, reduction="sum")
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
if not self.regression_target:
preds = logits.argmax(dim=1)
logging_output["ncorrect"] = (preds == targets).sum()
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
metrics.log_scalar(
"accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("sentence_ranking")
class SentenceRankingCriterion(FairseqCriterion):
def __init__(self, task, ranking_head_name, save_predictions, num_classes):
super().__init__(task)
self.ranking_head_name = ranking_head_name
if save_predictions is not None:
self.prediction_h = open(save_predictions, "w")
else:
self.prediction_h = None
self.num_classes = num_classes
def __del__(self):
if self.prediction_h is not None:
self.prediction_h.close()
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--save-predictions', metavar='FILE',
help='file to save predictions to')
parser.add_argument('--ranking-head-name',
default='sentence_classification_head',
help='name of the ranking head to use')
# fmt: on
def forward(self, model, sample, reduce=True):
"""Compute ranking loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model, "classification_heads")
and self.ranking_head_name in model.classification_heads
), "model must provide sentence ranking head for --criterion=sentence_ranking"
scores = []
for idx in range(self.num_classes):
score, _ = model(
**sample["net_input{idx}".format(idx=idx + 1)],
classification_head_name=self.ranking_head_name,
)
scores.append(score)
logits = torch.cat(scores, dim=1)
sample_size = logits.size(0)
if "target" in sample:
targets = model.get_targets(sample, [logits]).view(-1)
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = F.nll_loss(lprobs, targets, reduction="sum")
else:
targets = None
loss = torch.tensor(0.0, requires_grad=True)
if self.prediction_h is not None:
preds = logits.argmax(dim=1)
for i, (id, pred) in enumerate(zip(sample["id"].tolist(), preds.tolist())):
if targets is not None:
label = targets[i].item()
print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
else:
print("{}\t{}".format(id, pred), file=self.prediction_h)
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
if targets is not None:
logging_output["ncorrect"] = (logits.argmax(dim=1) == targets).sum()
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
metrics.log_scalar(
"accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.logging.meters import safe_round
@register_criterion("wav2vec")
class Wav2vecCriterion(FairseqCriterion):
def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
super().__init__(task)
self.infonce = infonce
self.loss_weights = None if loss_weights is None else eval(loss_weights)
self.log_keys = [] if log_keys is None else eval(log_keys)
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
# fmt: off
parser.add_argument('--infonce', action='store_true',
help='if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)')
parser.add_argument('--loss-weights', type=str, default=None,
help='weights for additional loss terms (not first one)')
parser.add_argument('--log-keys', type=str, default=None,
help='output keys to log')
# fmt: on
def forward(self, model, sample, reduce=True, log_pred=False):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
logits = model.get_logits(net_output).float()
target = model.get_targets(sample, net_output)
weights = None
if hasattr(model, "get_target_weights") and not self.infonce:
weights = model.get_target_weights(target, net_output)
if torch.is_tensor(weights):
weights = weights.float()
losses = []
if self.infonce:
loss = F.cross_entropy(
logits,
target,
reduction="sum" if reduce else "none",
)
else:
loss = F.binary_cross_entropy_with_logits(
logits,
target.float(),
weights,
reduction="sum" if reduce else "none",
)
sample_size = target.numel() if self.infonce else target.long().sum().item()
losses.append(loss.detach().clone())
if self.loss_weights is not None:
assert hasattr(model, "get_extra_losses")
extra_losses = model.get_extra_losses(net_output)
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, coef in zip(extra_losses, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
losses.append(p)
logging_output = {
"loss": loss.item() if reduce else loss,
"ntokens": sample_size,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
}
for lk in self.log_keys:
if lk in net_output:
logging_output[lk] = float((net_output[lk]))
if len(losses) > 1:
for i, l in enumerate(losses):
logging_output[f"loss_{i}"] = l.item()
if self.infonce:
with torch.no_grad():
if logits.numel() == 0:
corr = 0
count = 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
logging_output["correct"] = corr
logging_output["count"] = count
if log_pred:
logging_output["logits"] = logits.cpu().numpy()
logging_output["target"] = target.cpu().numpy()
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
correct = sum(log.get("correct", 0) for log in logging_outputs)
metrics.log_scalar("_correct", correct)
total = sum(log.get("count", 0) for log in logging_outputs)
metrics.log_scalar("_total", total)
if total > 0:
metrics.log_derived(
"accuracy",
lambda meters: safe_round(
meters["_correct"].sum / meters["_total"].sum, 5
)
if meters["_total"].sum > 0
else float("nan"),
)
builtin_keys = {
"loss",
"ntokens",
"nsentences",
"sample_size",
"correct",
"count",
}
for k in logging_outputs[0]:
if k not in builtin_keys:
val = sum(log.get(k, 0) for log in logging_outputs) / len(
logging_outputs
)
if k.startswith("loss"):
metrics.log_scalar(k, val / sample_size / math.log(2), sample_size)
else:
metrics.log_scalar(k, val, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return False
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
from .base_wrapper_dataset import BaseWrapperDataset
from .add_target_dataset import AddTargetDataset
from .append_token_dataset import AppendTokenDataset
from .audio.raw_audio_dataset import FileAudioDataset
from .backtranslation_dataset import BacktranslationDataset
from .bucket_pad_length_dataset import BucketPadLengthDataset
from .colorize_dataset import ColorizeDataset
from .concat_dataset import ConcatDataset
from .concat_sentences_dataset import ConcatSentencesDataset
from .denoising_dataset import DenoisingDataset
from .id_dataset import IdDataset
from .indexed_dataset import (
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
MMapIndexedDataset,
)
from .language_pair_dataset import LanguagePairDataset
from .list_dataset import ListDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .lru_cache_dataset import LRUCacheDataset
from .mask_tokens_dataset import MaskTokensDataset
from .monolingual_dataset import MonolingualDataset
from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from .nested_dictionary_dataset import NestedDictionaryDataset
from .noising import NoisingDataset
from .numel_dataset import NumelDataset
from .num_samples_dataset import NumSamplesDataset
from .offset_tokens_dataset import OffsetTokensDataset
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
from .prepend_dataset import PrependDataset
from .prepend_token_dataset import PrependTokenDataset
from .raw_label_dataset import RawLabelDataset
from .replace_dataset import ReplaceDataset
from .resampling_dataset import ResamplingDataset
from .roll_dataset import RollDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sort_dataset import SortDataset
from .strip_token_dataset import StripTokenDataset
from .subsample_dataset import SubsampleDataset
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
from .shorten_dataset import TruncateDataset, RandomCropDataset
from .multilingual.sampled_multi_dataset import SampledMultiDataset
from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
from .fasta_dataset import FastaDataset, EncodedFastaDataset
from .iterators import (
CountingIterator,
EpochBatchIterator,
GroupedIterator,
ShardedIterator,
)
__all__ = [
"AddTargetDataset",
"AppendTokenDataset",
"BacktranslationDataset",
"BaseWrapperDataset",
"BucketPadLengthDataset",
"ColorizeDataset",
"ConcatDataset",
"ConcatSentencesDataset",
"CountingIterator",
"DenoisingDataset",
"Dictionary",
"EncodedFastaDataset",
"EpochBatchIterator",
"FairseqDataset",
"FairseqIterableDataset",
"FastaDataset",
"GroupedIterator",
"IdDataset",
"IndexedCachedDataset",
"IndexedDataset",
"IndexedRawTextDataset",
"LanguagePairDataset",
"LeftPadDataset",
"ListDataset",
"LMContextWindowDataset",
"LRUCacheDataset",
"MaskTokensDataset",
"MMapIndexedDataset",
"MonolingualDataset",
"MultiCorpusSampledDataset",
"NestedDictionaryDataset",
"NoisingDataset",
"NumelDataset",
"NumSamplesDataset",
"OffsetTokensDataset",
"PadDataset",
"PrependDataset",
"PrependTokenDataset",
"ReplaceDataset",
"RollDataset",
"FileAudioDataset",
"RawLabelDataset",
"ResamplingDataset",
"RightPadDataset",
"RoundRobinZipDatasets",
"SampledMultiDataset",
"SampledMultiEpochDataset",
"ShardedIterator",
"SortDataset",
"StripTokenDataset",
"SubsampleDataset",
"TokenBlockDataset",
"TransformEosDataset",
"TransformEosLangPairDataset",
"TruncateDataset",
"TruncatedDictionary",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment