Commit 18d27e00 authored by wangwei990215's avatar wangwei990215
Browse files

initial commit

parent 541f4c7a
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from torch import Tensor
@register_criterion("nat_loss")
class LabelSmoothedDualImitationCriterion(FairseqCriterion):
def __init__(self, task, label_smoothing):
super().__init__(task)
self.label_smoothing = label_smoothing
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
parser.add_argument(
"--label-smoothing",
default=0.0,
type=float,
metavar="D",
help="epsilon for label smoothing, 0 means no label smoothing",
)
def _compute_loss(
self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0
):
"""
outputs: batch x len x d_model
targets: batch x len
masks: batch x len
policy_logprob: if there is some policy
depends on the likelihood score as rewards.
"""
def mean_ds(x: Tensor, dim=None) -> Tensor:
return (
x.float().mean().type_as(x)
if dim is None
else x.float().mean(dim).type_as(x)
)
if masks is not None:
outputs, targets = outputs[masks], targets[masks]
if masks is not None and not masks.any():
nll_loss = torch.tensor(0)
loss = nll_loss
else:
logits = F.log_softmax(outputs, dim=-1)
if targets.dim() == 1:
losses = F.nll_loss(logits, targets.to(logits.device), reduction="none")
else: # soft-labels
losses = F.kl_div(logits, targets.to(logits.device), reduction="none")
losses = losses.sum(-1)
nll_loss = mean_ds(losses)
if label_smoothing > 0:
loss = (
nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing
)
else:
loss = nll_loss
loss = loss * factor
return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor}
def _custom_loss(self, loss, name="loss", factor=1.0):
return {"name": name, "loss": loss, "factor": factor}
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
nsentences, ntokens = sample["nsentences"], sample["ntokens"]
# B x T
src_tokens, src_lengths = (
sample["net_input"]["src_tokens"],
sample["net_input"]["src_lengths"],
)
tgt_tokens, prev_output_tokens = sample["target"], sample["prev_target"]
outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens)
losses, nll_loss = [], []
for obj in outputs:
if outputs[obj].get("loss", None) is None:
_losses = self._compute_loss(
outputs[obj].get("out"),
outputs[obj].get("tgt"),
outputs[obj].get("mask", None),
outputs[obj].get("ls", 0.0),
name=obj + "-loss",
factor=outputs[obj].get("factor", 1.0),
)
else:
_losses = self._custom_loss(
outputs[obj].get("loss"),
name=obj + "-loss",
factor=outputs[obj].get("factor", 1.0),
)
losses += [_losses]
if outputs[obj].get("nll_loss", False):
nll_loss += [_losses.get("nll_loss", 0.0)]
loss = sum(l["loss"] for l in losses)
nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 else loss.new_tensor(0)
# NOTE:
# we don't need to use sample_size as denominator for the gradient
# here sample_size is just used for logging
sample_size = 1
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
for l in losses:
logging_output[l["name"]] = (
utils.item(l["loss"].data / l["factor"])
if reduce
else l[["loss"]].data / l["factor"]
)
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
nll_loss = utils.item(sum(log.get("nll_loss", 0) for log in logging_outputs))
metrics.log_scalar(
"loss", loss / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss / sample_size / math.log(2), sample_size, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
for key in logging_outputs[0]:
if key[-5:] == "-loss":
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(
key[:-5],
val / sample_size / math.log(2) if sample_size > 0 else 0.0,
sample_size,
round=3,
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("sentence_prediction")
class SentencePredictionCriterion(FairseqCriterion):
def __init__(self, task, classification_head_name, regression_target):
super().__init__(task)
self.classification_head_name = classification_head_name
self.regression_target = regression_target
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--classification-head-name',
default='sentence_classification_head',
help='name of the classification head to use')
# fmt: on
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model, "classification_heads")
and self.classification_head_name in model.classification_heads
), "model must provide sentence classification head for --criterion=sentence_prediction"
logits, _ = model(
**sample["net_input"],
features_only=True,
classification_head_name=self.classification_head_name,
)
targets = model.get_targets(sample, [logits]).view(-1)
sample_size = targets.numel()
if not self.regression_target:
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = F.nll_loss(lprobs, targets, reduction="sum")
else:
logits = logits.view(-1).float()
targets = targets.float()
loss = F.mse_loss(logits, targets, reduction="sum")
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
if not self.regression_target:
preds = logits.argmax(dim=1)
logging_output["ncorrect"] = (preds == targets).sum()
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
metrics.log_scalar(
"accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("sentence_ranking")
class SentenceRankingCriterion(FairseqCriterion):
def __init__(self, task, ranking_head_name, save_predictions, num_classes):
super().__init__(task)
self.ranking_head_name = ranking_head_name
if save_predictions is not None:
self.prediction_h = open(save_predictions, "w")
else:
self.prediction_h = None
self.num_classes = num_classes
def __del__(self):
if self.prediction_h is not None:
self.prediction_h.close()
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--save-predictions', metavar='FILE',
help='file to save predictions to')
parser.add_argument('--ranking-head-name',
default='sentence_classification_head',
help='name of the ranking head to use')
# fmt: on
def forward(self, model, sample, reduce=True):
"""Compute ranking loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model, "classification_heads")
and self.ranking_head_name in model.classification_heads
), "model must provide sentence ranking head for --criterion=sentence_ranking"
scores = []
for idx in range(self.num_classes):
score, _ = model(
**sample["net_input{idx}".format(idx=idx + 1)],
classification_head_name=self.ranking_head_name,
)
scores.append(score)
logits = torch.cat(scores, dim=1)
sample_size = logits.size(0)
if "target" in sample:
targets = model.get_targets(sample, [logits]).view(-1)
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = F.nll_loss(lprobs, targets, reduction="sum")
else:
targets = None
loss = torch.tensor(0.0, requires_grad=True)
if self.prediction_h is not None:
preds = logits.argmax(dim=1)
for i, (id, pred) in enumerate(zip(sample["id"].tolist(), preds.tolist())):
if targets is not None:
label = targets[i].item()
print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
else:
print("{}\t{}".format(id, pred), file=self.prediction_h)
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
if targets is not None:
logging_output["ncorrect"] = (logits.argmax(dim=1) == targets).sum()
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
metrics.log_scalar(
"accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.logging.meters import safe_round
@register_criterion("wav2vec")
class Wav2vecCriterion(FairseqCriterion):
def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
super().__init__(task)
self.infonce = infonce
self.loss_weights = None if loss_weights is None else eval(loss_weights)
self.log_keys = [] if log_keys is None else eval(log_keys)
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
# fmt: off
parser.add_argument('--infonce', action='store_true',
help='if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)')
parser.add_argument('--loss-weights', type=str, default=None,
help='weights for additional loss terms (not first one)')
parser.add_argument('--log-keys', type=str, default=None,
help='output keys to log')
# fmt: on
def forward(self, model, sample, reduce=True, log_pred=False):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
logits = model.get_logits(net_output).float()
target = model.get_targets(sample, net_output)
weights = None
if hasattr(model, "get_target_weights") and not self.infonce:
weights = model.get_target_weights(target, net_output)
if torch.is_tensor(weights):
weights = weights.float()
losses = []
if self.infonce:
loss = F.cross_entropy(
logits,
target,
reduction="sum" if reduce else "none",
)
else:
loss = F.binary_cross_entropy_with_logits(
logits,
target.float(),
weights,
reduction="sum" if reduce else "none",
)
sample_size = target.numel() if self.infonce else target.long().sum().item()
losses.append(loss.detach().clone())
if self.loss_weights is not None:
assert hasattr(model, "get_extra_losses")
extra_losses = model.get_extra_losses(net_output)
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, coef in zip(extra_losses, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
losses.append(p)
logging_output = {
"loss": loss.item() if reduce else loss,
"ntokens": sample_size,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
}
for lk in self.log_keys:
if lk in net_output:
logging_output[lk] = float((net_output[lk]))
if len(losses) > 1:
for i, l in enumerate(losses):
logging_output[f"loss_{i}"] = l.item()
if self.infonce:
with torch.no_grad():
if logits.numel() == 0:
corr = 0
count = 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
logging_output["correct"] = corr
logging_output["count"] = count
if log_pred:
logging_output["logits"] = logits.cpu().numpy()
logging_output["target"] = target.cpu().numpy()
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
correct = sum(log.get("correct", 0) for log in logging_outputs)
metrics.log_scalar("_correct", correct)
total = sum(log.get("count", 0) for log in logging_outputs)
metrics.log_scalar("_total", total)
if total > 0:
metrics.log_derived(
"accuracy",
lambda meters: safe_round(
meters["_correct"].sum / meters["_total"].sum, 5
)
if meters["_total"].sum > 0
else float("nan"),
)
builtin_keys = {
"loss",
"ntokens",
"nsentences",
"sample_size",
"correct",
"count",
}
for k in logging_outputs[0]:
if k not in builtin_keys:
val = sum(log.get(k, 0) for log in logging_outputs) / len(
logging_outputs
)
if k.startswith("loss"):
metrics.log_scalar(k, val / sample_size / math.log(2), sample_size)
else:
metrics.log_scalar(k, val, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return False
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
from .base_wrapper_dataset import BaseWrapperDataset
from .add_target_dataset import AddTargetDataset
from .append_token_dataset import AppendTokenDataset
from .audio.raw_audio_dataset import FileAudioDataset
from .backtranslation_dataset import BacktranslationDataset
from .bucket_pad_length_dataset import BucketPadLengthDataset
from .colorize_dataset import ColorizeDataset
from .concat_dataset import ConcatDataset
from .concat_sentences_dataset import ConcatSentencesDataset
from .denoising_dataset import DenoisingDataset
from .id_dataset import IdDataset
from .indexed_dataset import (
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
MMapIndexedDataset,
)
from .language_pair_dataset import LanguagePairDataset
from .list_dataset import ListDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .lru_cache_dataset import LRUCacheDataset
from .mask_tokens_dataset import MaskTokensDataset
from .monolingual_dataset import MonolingualDataset
from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from .nested_dictionary_dataset import NestedDictionaryDataset
from .noising import NoisingDataset
from .numel_dataset import NumelDataset
from .num_samples_dataset import NumSamplesDataset
from .offset_tokens_dataset import OffsetTokensDataset
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
from .prepend_dataset import PrependDataset
from .prepend_token_dataset import PrependTokenDataset
from .raw_label_dataset import RawLabelDataset
from .replace_dataset import ReplaceDataset
from .resampling_dataset import ResamplingDataset
from .roll_dataset import RollDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sort_dataset import SortDataset
from .strip_token_dataset import StripTokenDataset
from .subsample_dataset import SubsampleDataset
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
from .shorten_dataset import TruncateDataset, RandomCropDataset
from .multilingual.sampled_multi_dataset import SampledMultiDataset
from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
from .fasta_dataset import FastaDataset, EncodedFastaDataset
from .iterators import (
CountingIterator,
EpochBatchIterator,
GroupedIterator,
ShardedIterator,
)
__all__ = [
"AddTargetDataset",
"AppendTokenDataset",
"BacktranslationDataset",
"BaseWrapperDataset",
"BucketPadLengthDataset",
"ColorizeDataset",
"ConcatDataset",
"ConcatSentencesDataset",
"CountingIterator",
"DenoisingDataset",
"Dictionary",
"EncodedFastaDataset",
"EpochBatchIterator",
"FairseqDataset",
"FairseqIterableDataset",
"FastaDataset",
"GroupedIterator",
"IdDataset",
"IndexedCachedDataset",
"IndexedDataset",
"IndexedRawTextDataset",
"LanguagePairDataset",
"LeftPadDataset",
"ListDataset",
"LMContextWindowDataset",
"LRUCacheDataset",
"MaskTokensDataset",
"MMapIndexedDataset",
"MonolingualDataset",
"MultiCorpusSampledDataset",
"NestedDictionaryDataset",
"NoisingDataset",
"NumelDataset",
"NumSamplesDataset",
"OffsetTokensDataset",
"PadDataset",
"PrependDataset",
"PrependTokenDataset",
"ReplaceDataset",
"RollDataset",
"FileAudioDataset",
"RawLabelDataset",
"ResamplingDataset",
"RightPadDataset",
"RoundRobinZipDatasets",
"SampledMultiDataset",
"SampledMultiEpochDataset",
"ShardedIterator",
"SortDataset",
"StripTokenDataset",
"SubsampleDataset",
"TokenBlockDataset",
"TransformEosDataset",
"TransformEosLangPairDataset",
"TruncateDataset",
"TruncatedDictionary",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment