Commit 39ac40a9 authored by chenzk

v1.0

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# @package _group_
hydra:
run:
dir: .
defaults:
- task: null
- model: null
- criterion: cross_entropy
- optimizer: null
- lr_scheduler: fixed
- bpe: null
- tokenizer: null
- scoring: null
- generation: null
- common_eval: null
- eval_lm: null
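# Note (added for illustration): each group above can be overridden per run from the
# command line, e.g. `fairseq-hydra-train criterion=ctc optimizer=adam ...`; the exact
# override names depend on which choices are registered in this tree.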
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import importlib
import os
from fairseq import registry
from fairseq.criterions.fairseq_criterion import ( # noqa
FairseqCriterion,
LegacyFairseqCriterion,
)
from omegaconf import DictConfig
(
build_criterion_,
register_criterion,
CRITERION_REGISTRY,
CRITERION_DATACLASS_REGISTRY,
) = registry.setup_registry(
"--criterion", base_class=FairseqCriterion, default="cross_entropy"
)
def build_criterion(cfg: DictConfig, task):
return build_criterion_(cfg, task)
# automatically import any Python files in the criterions/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
file_name = file[: file.find(".py")]
importlib.import_module("fairseq.criterions." + file_name)
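# Usage sketch (illustration only; "my_criterion", MyCriterion and MyCriterionConfig are
# hypothetical names): a criterion registers itself with the decorator created above and
# is later constructed from its dataclass config, e.g.
#
#     @register_criterion("my_criterion", dataclass=MyCriterionConfig)
#     class MyCriterion(FairseqCriterion):
#         ...
#
#     criterion = build_criterion(cfg.criterion, task)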
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.constants import DDP_BACKEND_CHOICES
from omegaconf import II
@dataclass
class AdaptiveLossConfig(FairseqDataclass):
sentence_avg: bool = II("optimization.sentence_avg")
ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend")
@register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig)
class AdaptiveLoss(FairseqCriterion):
"""This is an implementation of the loss function accompanying the adaptive softmax approximation for
    graphics processing units (GPUs), described in the paper "Efficient softmax approximation for GPUs"
(http://arxiv.org/abs/1609.04309)."""
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg
@classmethod
def build_criterion(cls, cfg: AdaptiveLossConfig, task):
if cfg.ddp_backend in {"c10d", "pytorch_ddp"}:
raise Exception(
"AdaptiveLoss is not compatible with the PyTorch "
"version of DistributedDataParallel. Please use "
"`--ddp-backend=legacy_ddp` instead."
)
return cls(task, cfg.sentence_avg)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model.decoder, "adaptive_softmax")
and model.decoder.adaptive_softmax is not None
)
adaptive_softmax = model.decoder.adaptive_softmax
net_output = model(**sample["net_input"])
orig_target = model.get_targets(sample, net_output)
nsentences = orig_target.size(0)
orig_target = orig_target.view(-1)
bsz = orig_target.size(0)
logits, target = adaptive_softmax(net_output[0], orig_target)
assert len(target) == len(logits)
loss = net_output[0].new(1 if reduce else bsz).zero_()
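        # adaptive_softmax partitions the vocabulary into clusters; `logits` and `target`
        # are per-cluster lists, so the cross entropy is accumulated cluster by cluster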
for i in range(len(target)):
if target[i] is not None:
assert target[i].min() >= 0 and target[i].max() <= logits[i].size(1)
loss += F.cross_entropy(
logits[i],
target[i],
ignore_index=self.padding_idx,
reduction="sum" if reduce else "none",
)
orig = utils.strip_pad(orig_target, self.padding_idx)
ntokens = orig.numel()
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
logging_output = {
"loss": loss.data,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
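        # the summed loss is in nats; dividing by log(2) reports it in bits, matching the
        # convention used by the other criterions in this package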
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
@dataclass
class CrossEntropyCriterionConfig(FairseqDataclass):
sentence_avg: bool = II("optimization.sentence_avg")
@register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig)
class CrossEntropyCriterion(FairseqCriterion):
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
return loss, sample_size, logging_output
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1)
loss = F.nll_loss(
lprobs,
target,
ignore_index=self.padding_idx,
reduction="sum" if reduce else "none",
)
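        # returned twice to keep the (loss, nll_loss) return convention; for plain cross
        # entropy the two values are identical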
return loss, loss
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
# we divide by log(2) to convert the loss from base e to base 2
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
"""
return True
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
from argparse import Namespace
from dataclasses import dataclass, field
from omegaconf import II
from typing import Optional
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import post_process
from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round
@dataclass
class CtcCriterionConfig(FairseqDataclass):
zero_infinity: bool = field(
default=False,
metadata={"help": "zero inf loss when source length <= target length"},
)
sentence_avg: bool = II("optimization.sentence_avg")
post_process: str = field(
default="letter",
metadata={
"help": "how to post process predictions into words. can be letter, "
"wordpiece, BPE symbols, etc. "
"See fairseq.data.data_utils.post_process() for full list of options"
},
)
wer_kenlm_model: Optional[str] = field(
default=None,
metadata={
"help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
},
)
wer_lexicon: Optional[str] = field(
default=None,
metadata={"help": "lexicon to use with wer_kenlm_model"},
)
wer_lm_weight: float = field(
default=2.0,
metadata={"help": "lm weight to use with wer_kenlm_model"},
)
wer_word_score: float = field(
default=-1.0,
metadata={"help": "lm word score to use with wer_kenlm_model"},
)
wer_args: Optional[str] = field(
default=None,
metadata={
"help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
},
)
@register_criterion("ctc", dataclass=CtcCriterionConfig)
class CtcCriterion(FairseqCriterion):
def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask):
super().__init__(task)
self.blank_idx = (
task.target_dictionary.index(task.blank_symbol)
if hasattr(task, "blank_symbol")
else 0
)
self.pad_idx = task.target_dictionary.pad()
self.eos_idx = task.target_dictionary.eos()
self.post_process = cfg.post_process
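        # backwards compatibility: the deprecated `wer_args` option packs the four wer_*
        # settings into one string, which is unpacked into the newer fields below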
if cfg.wer_args is not None:
(
cfg.wer_kenlm_model,
cfg.wer_lexicon,
cfg.wer_lm_weight,
cfg.wer_word_score,
) = eval(cfg.wer_args)
if cfg.wer_kenlm_model is not None:
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
dec_args = Namespace()
dec_args.nbest = 1
dec_args.criterion = "ctc"
dec_args.kenlm_model = cfg.wer_kenlm_model
dec_args.lexicon = cfg.wer_lexicon
dec_args.beam = 50
dec_args.beam_size_token = min(50, len(task.target_dictionary))
dec_args.beam_threshold = min(50, len(task.target_dictionary))
dec_args.lm_weight = cfg.wer_lm_weight
dec_args.word_score = cfg.wer_word_score
dec_args.unk_weight = -math.inf
dec_args.sil_weight = 0
self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
else:
self.w2l_decoder = None
self.zero_infinity = cfg.zero_infinity
self.sentence_avg = cfg.sentence_avg
def get_net_output(self, model, sample):
net_output = model(**sample["net_input"])
return net_output
def get_loss(self, model, sample, net_output, reduce=True):
lprobs = model.get_normalized_probs(
net_output, log_probs=True
).contiguous() # (T, B, C) from the encoder
if "src_lengths" in sample["net_input"]:
input_lengths = sample["net_input"]["src_lengths"]
else:
if net_output["padding_mask"] is not None:
non_padding_mask = ~net_output["padding_mask"]
input_lengths = non_padding_mask.long().sum(-1)
else:
input_lengths = lprobs.new_full(
(lprobs.size(1),), lprobs.size(0), dtype=torch.long
)
pad_mask = (sample["target"] != self.pad_idx) & (
sample["target"] != self.eos_idx
)
targets_flat = sample["target"].masked_select(pad_mask)
if "target_lengths" in sample:
target_lengths = sample["target_lengths"]
else:
target_lengths = pad_mask.sum(-1)
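        # run CTC with cuDNN disabled (native kernel); lprobs is (T, B, C) and the targets
        # are flattened with padding/eos removed, as F.ctc_loss expects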
with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss(
lprobs,
targets_flat,
input_lengths,
target_lengths,
blank=self.blank_idx,
reduction="sum",
zero_infinity=self.zero_infinity,
)
ntokens = (
sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
)
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
logging_output = {
"loss": utils.item(loss.data), # * sample['ntokens'],
"ntokens": ntokens,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
}
if not model.training:
import editdistance
with torch.no_grad():
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
c_err = 0
c_len = 0
w_errs = 0
w_len = 0
wv_errs = 0
for lp, t, inp_l in zip(
lprobs_t,
sample["target_label"]
if "target_label" in sample
else sample["target"],
input_lengths,
):
lp = lp[:inp_l].unsqueeze(0)
decoded = None
if self.w2l_decoder is not None:
decoded = self.w2l_decoder.decode(lp)
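                        # the decoder returns a list (per batch element) of n-best lists;
                        # both levels are unwrapped below, falling back to None when empty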
if len(decoded) < 1:
decoded = None
else:
decoded = decoded[0]
if len(decoded) < 1:
decoded = None
else:
decoded = decoded[0]
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
)
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
targ_units_arr = targ.tolist()
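                    # greedy CTC decoding of the model output: argmax per frame, collapse
                    # repeated tokens, then drop the blank symbol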
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
targ_words = post_process(targ_units, self.post_process).split()
pred_units = self.task.target_dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
if decoded is not None and "words" in decoded:
pred_words = decoded["words"]
w_errs += editdistance.eval(pred_words, targ_words)
wv_errs += editdistance.eval(pred_words_raw, targ_words)
else:
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
wv_errs += dist
w_len += len(targ_words)
logging_output["wv_errors"] = wv_errs
logging_output["w_errors"] = w_errs
logging_output["w_total"] = w_len
logging_output["c_errors"] = c_err
logging_output["c_total"] = c_len
return loss, sample_size, logging_output
def forward(self, model, sample, reduce=True):
net_output = self.get_net_output(model, sample)
loss, sample_size, logging_output = self.get_loss(model, sample, net_output, reduce)
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
metrics.log_scalar("_c_errors", c_errors)
c_total = sum(log.get("c_total", 0) for log in logging_outputs)
metrics.log_scalar("_c_total", c_total)
w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
metrics.log_scalar("_w_errors", w_errors)
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
metrics.log_scalar("_wv_errors", wv_errors)
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
metrics.log_scalar("_w_total", w_total)
if c_total > 0:
metrics.log_derived(
"uer",
lambda meters: safe_round(
meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
)
if meters["_c_total"].sum > 0
else float("nan"),
)
if w_total > 0:
metrics.log_derived(
"wer",
lambda meters: safe_round(
meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
metrics.log_derived(
"raw_wer",
lambda meters: safe_round(
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from typing import Any, Dict, List
from fairseq import metrics, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from torch.nn.modules.loss import _Loss
class FairseqCriterion(_Loss):
def __init__(self, task):
super().__init__()
self.task = task
if hasattr(task, "target_dictionary"):
tgt_dict = task.target_dictionary
self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100
@classmethod
def add_args(cls, parser):
"""Add criterion-specific arguments to the parser."""
dc = getattr(cls, "__dataclass", None)
if dc is not None:
gen_parser_from_dataclass(parser, dc())
@classmethod
def build_criterion(cls, cfg: FairseqDataclass, task):
"""Construct a criterion from command-line args."""
        # infer the arguments to cls.__init__ from its signature
init_args = {}
for p in inspect.signature(cls).parameters.values():
if (
p.kind == p.POSITIONAL_ONLY
or p.kind == p.VAR_POSITIONAL
or p.kind == p.VAR_KEYWORD
):
# we haven't implemented inference for these argument types,
# but PRs welcome :)
raise NotImplementedError("{} not supported".format(p.kind))
assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}
if p.name == "task":
init_args["task"] = task
elif p.name == "cfg":
init_args["cfg"] = cfg
elif hasattr(cfg, p.name):
init_args[p.name] = getattr(cfg, p.name)
elif p.default != p.empty:
pass # we'll use the default value
else:
raise NotImplementedError(
"Unable to infer Criterion arguments, please implement "
"{}.build_criterion".format(cls.__name__)
)
return cls(**init_args)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
raise NotImplementedError
@staticmethod
def aggregate_logging_outputs(
logging_outputs: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(
"The aggregate_logging_outputs API is deprecated. "
"Please use the reduce_metrics API instead."
)
raise NotImplementedError
@classmethod
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
"""Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(
"Criterions should implement the reduce_metrics API. "
"Falling back to deprecated aggregate_logging_outputs API."
)
agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
for k, v in agg_logging_outputs.items():
if k in {"nsentences", "ntokens", "sample_size"}:
continue
metrics.log_scalar(k, v)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
"""
return False
class LegacyFairseqCriterion(FairseqCriterion):
def __init__(self, args, task):
super().__init__(task=task)
self.args = args
utils.deprecation_warning(
"Criterions should take explicit arguments instead of an "
"argparse.Namespace object, please update your criterion by "
"extending FairseqCriterion instead of LegacyFairseqCriterion."
)
@classmethod
def build_criterion(cls, args, task):
"""Construct a criterion from command-line args."""
return cls(args, task)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import re
from dataclasses import dataclass, field
from typing import List, Optional
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.models.hubert import ILSHubertModel
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
@dataclass
class HubertCriterionConfig(FairseqDataclass):
pred_masked_weight: float = field(
default=1.0,
metadata={"help": "weight for predictive loss for masked frames"},
)
pred_nomask_weight: float = field(
default=0.0,
metadata={"help": "weight for predictive loss for unmasked frames"},
)
loss_weights: Optional[List[float]] = field(
default=None,
metadata={"help": "weights for additional loss terms (not first one)"},
)
log_keys: List[str] = field(
default_factory=lambda: [],
metadata={"help": "output keys to log"},
)
@register_criterion("hubert", dataclass=HubertCriterionConfig)
class HubertCriterion(FairseqCriterion):
def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None):
super().__init__(task)
self.pred_masked_weight = pred_masked_weight
self.pred_nomask_weight = pred_nomask_weight
self.loss_weights = loss_weights
self.log_keys = [] if log_keys is None else log_keys
def get_net_output(self, model, sample):
"""compute the loss for the given sample"""
net_output = model(target_list=sample["target_list"], **sample["net_input"])
return net_output
def get_loss(self, model, sample, net_output, reduce=True, log_pred=False):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
loss = 0.
sample_size = 0
logging_output = {}
reduction = "sum" if reduce else "none"
loss_m_list = []
logp_m_list = model.get_logits(net_output, True)
targ_m_list = model.get_targets(net_output, True)
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
loss_m_list.append(loss_m)
logging_output[f"loss_m_{i}"] = loss_m.detach().item()
if self.pred_masked_weight > 0:
if isinstance(model, ILSHubertModel):
if model.weighted_sum:
norm_weights = F.softmax(model.weights, dim=-1)
loss_m_list = norm_weights * torch.stack(loss_m_list, dim=0)
loss += self.pred_masked_weight * sum(loss_m_list)
sample_size += targ_m_list[0].numel()
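        # predictive loss over unmasked frames, accumulated the same way and weighted by
        # pred_nomask_weight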
loss_u_list = []
logp_u_list = model.get_logits(net_output, False)
targ_u_list = model.get_targets(net_output, False)
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
loss_u_list.append(loss_u)
logging_output[f"loss_u_{i}"] = loss_u.detach().item()
if self.pred_nomask_weight > 0:
if model.weighted_sum:
norm_weights = F.softmax(model.weights, dim=-1)
loss_u_list = norm_weights * torch.stack(loss_u_list, dim=0)
loss += self.pred_nomask_weight * sum(loss_u_list)
sample_size += targ_u_list[0].numel()
if self.loss_weights is not None:
assert hasattr(model, "get_extra_losses")
extra_losses, names = model.get_extra_losses(net_output)
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
logging_output[f"loss_{n}"] = p.item()
logging_output = {
"loss": loss.item() if reduce else loss,
"ntokens": sample_size,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
**logging_output,
}
for lk in self.log_keys:
if lk in net_output:
logging_output[lk] = float((net_output[lk]))
def compute_correct(logits):
if logits.numel() == 0:
return 0, 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
return corr, count
with torch.no_grad():
for i, logp_m in enumerate(logp_m_list):
corr_m, count_m = compute_correct(logp_m)
logging_output[f"correct_m_{i}"] = corr_m
logging_output[f"count_m_{i}"] = count_m
for i, logp_u in enumerate(logp_u_list):
corr_u, count_u = compute_correct(logp_u)
logging_output[f"correct_u_{i}"] = corr_u
logging_output[f"count_u_{i}"] = count_u
return loss, sample_size, logging_output
def forward(self, model, sample, reduce=True, log_pred=False):
net_output = self.get_net_output(model, sample)
loss, sample_size, logging_output = self.get_loss(model, sample, net_output, reduce, log_pred)
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
if sample_size != ntokens:
metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
else:
metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))
counts = {}
for lk in logging_outputs[0].keys():
if lk.startswith("count_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val)
counts[lk] = val
for lk in logging_outputs[0].keys():
if lk.startswith("loss_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
elif lk.startswith("correct_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
raise NotImplementedError()
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
"""
return False
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import math
from dataclasses import dataclass, field
from fairseq import pdb
from fairseq import utils, metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.wav2vec_criterion import Wav2vecCriterion, Wav2VecCriterionConfig
from fairseq.criterions.ctc import CtcCriterion, CtcCriterionConfig
from fairseq.logging.meters import safe_round
@dataclass
class UnispeechCriterionConfig(Wav2VecCriterionConfig, CtcCriterionConfig):
mtlalpha: float = field(
default=0.5, metadata={"help": "loss weight for multitask learning"}
)
@register_criterion('unispeech_criterion', dataclass=UnispeechCriterionConfig)
class UnispeechCriterion(FairseqCriterion):
    def __init__(self, cfg: UnispeechCriterionConfig, task):
super().__init__(task)
self.mtlalpha = cfg.mtlalpha
self.w2v_criterion = Wav2vecCriterion(task, cfg.infonce, cfg.loss_weights, cfg.log_keys)
if self.mtlalpha > 0:
self.ctc_criterion = CtcCriterion(cfg, task)
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
if self.mtlalpha > 0.0:
ctc_loss, ctc_sample_size, ctc_logging_output = self.ctc_criterion.get_loss(model, sample, net_output, reduce)
else:
ctc_loss = 0
ctc_sample_size = 0
ctc_logging_output = {}
infonce_loss, infonce_sample_size, infonce_logging_output = self.w2v_criterion.get_loss(model.w2v_encoder.w2v_model, sample, net_output['contrastive_res'], reduce)
loss = self.mtlalpha * ctc_loss + (1.0 - self.mtlalpha) * infonce_loss
sample_size = infonce_sample_size
logging_output = {'loss': loss, 'ntokens': ctc_logging_output['ntokens'], 'nsentences': ctc_logging_output['nsentences'],
'ctc': ctc_logging_output, 'infonce': infonce_logging_output}
return loss, sample_size, logging_output
@staticmethod
def logging_outputs_can_be_summed() -> bool:
return False
@staticmethod
def reduce_metrics(logging_outputs) -> None:
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ctc_loss_sum = utils.item(sum(log['ctc'].get('loss', 0) for log in logging_outputs))
ctc_sample_size = utils.item(sum(log['ctc'].get('sample_size', 0) for log in logging_outputs))
ctc_ntokens = utils.item(sum(log['ctc'].get('ntokens', 0) for log in logging_outputs))
ctc_nsentences = utils.item(sum(log['ctc'].get('nsentences', 0) for log in logging_outputs))
ctras_loss_sum = utils.item(sum(log['infonce'].get('loss', 0) for log in logging_outputs))
ctras_sample_size = utils.item(sum(log['infonce'].get('sample_size', 0) for log in logging_outputs))
ctras_ntokens = utils.item(sum(log['infonce'].get('ntokens', 0) for log in logging_outputs))
ctras_nsentences = utils.item(sum(log['infonce'].get('nsentences', 0) for log in logging_outputs))
metrics.log_scalar(
"loss", loss_sum, 1, round=3)
metrics.log_scalar(
"ctc_loss", ctc_loss_sum / ctc_sample_size / math.log(2), ctc_sample_size, round=3
)
metrics.log_scalar(
"contrastive_loss", ctras_loss_sum / ctras_sample_size / math.log(2), ctras_sample_size, round=3
)
if ctc_sample_size != ctc_ntokens:
metrics.log_scalar(
"nll_loss", ctc_loss_sum / ctc_ntokens / math.log(2), ctc_ntokens, round=3
)
c_errors = sum(log['ctc'].get("c_errors", 0) for log in logging_outputs)
metrics.log_scalar("_c_errors", c_errors)
c_total = sum(log['ctc'].get("c_total", 0) for log in logging_outputs)
metrics.log_scalar("_c_total", c_total)
w_errors = sum(log['ctc'].get("w_errors", 0) for log in logging_outputs)
metrics.log_scalar("_w_errors", w_errors)
wv_errors = sum(log['ctc'].get("wv_errors", 0) for log in logging_outputs)
metrics.log_scalar("_wv_errors", wv_errors)
w_total = sum(log['ctc'].get("w_total", 0) for log in logging_outputs)
metrics.log_scalar("_w_total", w_total)
if c_total > 0:
metrics.log_derived(
"uer",
lambda meters: safe_round(meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3)
if meters["_c_total"].sum > 0
else float("nan"),
)
if w_total > 0:
metrics.log_derived(
"wer",
lambda meters: safe_round(meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3)
if meters["_w_total"].sum > 0
else float("nan"),
)
metrics.log_derived(
"raw_wer",
lambda meters: safe_round(meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3)
if meters["_w_total"].sum > 0
else float("nan"),
)
metrics.log_scalar("nsentences", ctras_nsentences)
metrics.log_scalar("ctc_sample_size", ctc_sample_size)
metrics.log_scalar("contrastive_sample_size", ctras_sample_size)
correct = sum(log['infonce'].get("correct", 0) for log in logging_outputs)
metrics.log_scalar("_correct", correct)
total = sum(log['infonce'].get("count", 0) for log in logging_outputs)
metrics.log_scalar("_total", total)
if total > 0:
metrics.log_derived(
"accuracy",
lambda meters: safe_round(meters["_correct"].sum / meters["_total"].sum, 5)
if meters["_total"].sum > 0
else float("nan"),
)
builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'}
for k in logging_outputs[0]['infonce']:
if k not in builtin_keys:
val = sum(log['infonce'].get(k, 0) for log in logging_outputs) / len(logging_outputs)
if k.startswith('loss'):
metrics.log_scalar(k, val / ctras_sample_size / math.log(2), ctras_sample_size)
else:
metrics.log_scalar(k, val, round=3)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
from typing import List, Optional
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.logging.meters import safe_round
from fairseq.utils import is_xla_tensor
@dataclass
class Wav2VecCriterionConfig(FairseqDataclass):
infonce: bool = field(
default=False,
metadata={
"help": "if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)"
},
)
loss_weights: Optional[List[float]] = field(
default=None,
metadata={"help": "weights for additional loss terms (not first one)"},
)
log_keys: List[str] = field(
default_factory=lambda: [],
metadata={"help": "output keys to log"},
)
@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig)
class Wav2vecCriterion(FairseqCriterion):
def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
super().__init__(task)
self.infonce = infonce
self.loss_weights = loss_weights
self.log_keys = [] if log_keys is None else log_keys
def get_loss(self, model, sample, net_output, reduce=True, log_pred=False):
logits = model.get_logits(net_output).float()
target = model.get_targets(sample, net_output)
weights = None
if hasattr(model, "get_target_weights") and not self.infonce:
weights = model.get_target_weights(target, net_output)
if torch.is_tensor(weights):
weights = weights.float()
losses = []
if self.infonce:
loss = F.cross_entropy(
logits,
target,
reduction="sum" if reduce else "none",
)
else:
loss = F.binary_cross_entropy_with_logits(
logits,
target.float(),
weights,
reduction="sum" if reduce else "none",
)
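        # with InfoNCE every candidate row contributes to the sample size; with the binary
        # cross entropy variant only the positive targets are counted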
sample_size = target.numel() if self.infonce else target.long().sum().item()
losses.append(loss.detach().clone())
if self.loss_weights is not None:
assert hasattr(model, "get_extra_losses")
extra_losses = model.get_extra_losses(net_output)
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, coef in zip(extra_losses, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
losses.append(p)
logging_output = {
"loss": loss.item() if reduce else loss,
"ntokens": sample_size,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
}
for lk in self.log_keys:
if lk in net_output:
logging_output[lk] = float((net_output[lk]))
if len(losses) > 1:
for i, l in enumerate(losses):
logging_output[f"loss_{i}"] = l.item()
if self.infonce:
with torch.no_grad():
if logits.numel() == 0:
corr = 0
count = 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
logging_output["correct"] = corr
logging_output["count"] = count
if log_pred:
logging_output["logits"] = logits.cpu().numpy()
logging_output["target"] = target.cpu().numpy()
return loss, sample_size, logging_output
def get_net_output(self, model, sample):
net_output = model(**sample["net_input"])
return net_output
def forward(self, model, sample, reduce=True, log_pred=False):
net_output = self.get_net_output(model, sample)
loss, sample_size, logging_output = self.get_loss(model, sample, net_output, reduce, log_pred)
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
correct = sum(log.get("correct", 0) for log in logging_outputs)
metrics.log_scalar("_correct", correct)
total = sum(log.get("count", 0) for log in logging_outputs)
metrics.log_scalar("_total", total)
if total > 0:
metrics.log_derived(
"accuracy",
lambda meters: safe_round(
meters["_correct"].sum / meters["_total"].sum, 5
)
if meters["_total"].sum > 0
else float("nan"),
)
builtin_keys = {
"loss",
"ntokens",
"nsentences",
"sample_size",
"correct",
"count",
}
for k in logging_outputs[0]:
if k not in builtin_keys:
val = sum(log.get(k, 0) for log in logging_outputs) / len(
logging_outputs
)
if k.startswith("loss"):
metrics.log_scalar(k, val / sample_size / math.log(2), sample_size)
else:
metrics.log_scalar(k, val, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
"""
return False
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import re
from dataclasses import dataclass, field
from typing import List, Optional
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
@dataclass
class WavLMCriterionConfig(FairseqDataclass):
pred_masked_weight: float = field(
default=1.0,
metadata={"help": "weight for predictive loss for masked frames"},
)
pred_nomask_weight: float = field(
default=0.0,
metadata={"help": "weight for predictive loss for unmasked frames"},
)
loss_weights: Optional[List[float]] = field(
default=None,
metadata={"help": "weights for additional loss terms (not first one)"},
)
log_keys: List[str] = field(
default_factory=lambda: [],
metadata={"help": "output keys to log"},
)
@register_criterion("wavlm", dataclass=WavLMCriterionConfig)
class WavLMCriterion(FairseqCriterion):
def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None):
super().__init__(task)
self.pred_masked_weight = pred_masked_weight
self.pred_nomask_weight = pred_nomask_weight
self.loss_weights = loss_weights
self.log_keys = [] if log_keys is None else log_keys
def get_net_output(self, model, sample):
"""compute the loss for the given sample"""
net_output = model(target_list=sample["target_list"], **sample["net_input"])
return net_output
def get_loss(self, model, sample, net_output, reduce=True, log_pred=False):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
loss = 0.
sample_size = 0
logging_output = {}
reduction = "sum" if reduce else "none"
loss_m_list = []
logp_m_list = model.get_logits(net_output, True)
targ_m_list = model.get_targets(net_output, True)
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
loss_m_list.append(loss_m)
logging_output[f"loss_m_{i}"] = loss_m.detach().item()
if self.pred_masked_weight > 0:
loss += self.pred_masked_weight * sum(loss_m_list)
sample_size += targ_m_list[0].numel()
loss_u_list = []
logp_u_list = model.get_logits(net_output, False)
targ_u_list = model.get_targets(net_output, False)
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
loss_u_list.append(loss_u)
logging_output[f"loss_u_{i}"] = loss_u.detach().item()
if self.pred_nomask_weight > 0:
loss += self.pred_nomask_weight * sum(loss_u_list)
sample_size += targ_u_list[0].numel()
if self.loss_weights is not None:
assert hasattr(model, "get_extra_losses")
extra_losses, names = model.get_extra_losses(net_output)
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
logging_output[f"loss_{n}"] = p.item()
logging_output = {
"loss": loss.item() if reduce else loss,
"ntokens": sample_size,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
**logging_output,
}
for lk in self.log_keys:
if lk in net_output:
logging_output[lk] = float((net_output[lk]))
def compute_correct(logits):
if logits.numel() == 0:
return 0, 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
return corr, count
with torch.no_grad():
for i, logp_m in enumerate(logp_m_list):
corr_m, count_m = compute_correct(logp_m)
logging_output[f"correct_m_{i}"] = corr_m
logging_output[f"count_m_{i}"] = count_m
for i, logp_u in enumerate(logp_u_list):
corr_u, count_u = compute_correct(logp_u)
logging_output[f"correct_u_{i}"] = corr_u
logging_output[f"count_u_{i}"] = count_u
return loss, sample_size, logging_output
def forward(self, model, sample, reduce=True, log_pred=False):
net_output = self.get_net_output(model, sample)
loss, sample_size, logging_output = self.get_loss(model, sample, net_output, reduce, log_pred)
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
if sample_size != ntokens:
metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
else:
metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
builtin_keys = {
"loss",
"nll_loss",
"ppl",
"ntokens",
"nsentences",
"sample_size",
"correct",
"count",
}
counts = {}
for lk in logging_outputs[0].keys():
if lk.startswith("count_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val)
counts[lk] = val
builtin_keys.add(lk)
for lk in logging_outputs[0].keys():
if lk.startswith("loss_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
elif lk.startswith("correct_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
elif lk not in builtin_keys:
val = sum(log.get(lk, 0) for log in logging_outputs) / len(
logging_outputs
)
metrics.log_scalar(lk, val, round=3)
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
raise NotImplementedError()
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
"""
return False
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
from .base_wrapper_dataset import BaseWrapperDataset
from .add_target_dataset import AddTargetDataset
from .audio.raw_audio_dataset import FileAudioDataset
from .audio.hubert_dataset import HubertDataset
from .audio.utterance_mixing_dataset import UtteranceMixingDataset
from .concat_dataset import ConcatDataset
from .id_dataset import IdDataset
from .resampling_dataset import ResamplingDataset
from .iterators import (
CountingIterator,
EpochBatchIterator,
GroupedIterator,
ShardedIterator,
)
from .monolingual_dataset import MonolingualDataset
__all__ = [
"AddTargetDataset",
"ConcatDataset",
"CountingIterator",
"Dictionary",
"EpochBatchIterator",
"FairseqDataset",
"FairseqIterableDataset",
"FileAudioDataset",
"GroupedIterator",
"HubertDataset",
"IdDataset",
"ResamplingDataset",
"ShardedIterator",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import BaseWrapperDataset, data_utils
class AddTargetDataset(BaseWrapperDataset):
def __init__(
self,
dataset,
labels,
pad,
eos,
batch_targets,
process_label=None,
add_to_input=False,
):
super().__init__(dataset)
self.labels = labels
self.batch_targets = batch_targets
self.pad = pad
self.eos = eos
self.process_label = process_label
self.add_to_input = add_to_input
def get_label(self, index):
return (
self.labels[index]
if self.process_label is None
else self.process_label(self.labels[index])
)
def __getitem__(self, index):
item = self.dataset[index]
item["label"] = self.get_label(index)
return item
def size(self, index):
sz = self.dataset.size(index)
own_sz = len(self.get_label(index))
return (sz, own_sz)
def collater(self, samples):
collated = self.dataset.collater(samples)
if len(collated) == 0:
return collated
indices = set(collated["id"].tolist())
target = [s["label"] for s in samples if s["id"] in indices]
if self.batch_targets:
collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
collated["ntokens"] = collated["target_lengths"].sum().item()
else:
collated["ntokens"] = sum([len(t) for t in target])
collated["target"] = target
if self.add_to_input:
eos = target.new_full((target.size(0), 1), self.eos)
collated["target"] = torch.cat([target, eos], dim=-1).long()
collated["net_input"]["prev_output_tokens"] = torch.cat(
[eos, target], dim=-1
).long()
collated["ntokens"] += target.size(0)
return collated
from pathlib import Path
from typing import BinaryIO, Optional, Tuple, Union, List
import numpy as np
import torch
import io
import json
import librosa
import scipy
import soundfile as sf
SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"}
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
def preemphasis(x, preemph):
return scipy.signal.lfilter([1, -preemph], [1], x)
def mulaw_encode(x, mu):
mu = mu - 1
fx = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
return np.floor((fx + 1) / 2 * mu + 0.5)
def mulaw_decode(y, mu):
mu = mu - 1
x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)
return x
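def _mulaw_roundtrip_example():
    """Illustrative sketch (not part of the original API): round-trip a small signal
    through mulaw_encode/mulaw_decode with 256 quantization levels."""
    x = np.linspace(-1.0, 1.0, num=5)
    codes = mulaw_encode(x, mu=256)  # integer codes in [0, 255]
    # rescale the codes back to [-1, 1] before expanding, matching the encoder convention
    x_hat = mulaw_decode(2 * codes / 255 - 1, mu=256)
    return x, codes, x_hat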
def _convert_to_mono(
waveform: torch.FloatTensor, sample_rate: int
) -> torch.FloatTensor:
if waveform.shape[0] > 1:
try:
import torchaudio.sox_effects as ta_sox
except ImportError:
raise ImportError(
"Please install torchaudio to convert multi-channel audios"
)
effects = [['channels', '1']]
return ta_sox.apply_effects_tensor(waveform, sample_rate, effects)[0]
return waveform
def convert_to_mono(waveform: np.ndarray, sample_rate: int) -> np.ndarray:
if waveform.shape[0] > 1:
_waveform = torch.from_numpy(waveform)
return _convert_to_mono(_waveform, sample_rate).numpy()
return waveform
def get_waveform(
path_or_fp: Union[str, BinaryIO], normalization=True, mono=True,
frames=-1, start=0, always_2d=True
) -> Tuple[np.ndarray, int]:
"""Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.
Args:
path_or_fp (str or BinaryIO): the path or file-like object
normalization (bool): Normalize values to [-1, 1] (Default: True)
        mono (bool): convert multi-channel audio to a single channel
frames (int): the number of frames to read. (-1 for reading all)
start (int): Where to start reading. A negative value counts from the end.
always_2d (bool): always return 2D array even for mono-channel audios
Returns:
waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
sample_rate (float): sample rate
"""
if isinstance(path_or_fp, str):
ext = Path(path_or_fp).suffix
if ext not in SF_AUDIO_FILE_EXTENSIONS:
raise ValueError(f"Unsupported audio format: {ext}")
try:
import soundfile as sf
except ImportError:
raise ImportError(
"Please install soundfile to load WAV/FLAC/OGG Vorbis audios"
)
waveform, sample_rate = sf.read(
path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start
)
waveform = waveform.T # T x C -> C x T
if mono and waveform.shape[0] > 1:
waveform = convert_to_mono(waveform, sample_rate)
if not normalization:
waveform *= 2 ** 15 # denormalized to 16-bit signed integers
if not always_2d:
waveform = waveform.squeeze(axis=0)
return waveform, sample_rate
def _get_kaldi_fbank(
waveform: np.ndarray, sample_rate: int, n_bins=80
) -> Optional[np.ndarray]:
"""Get mel-filter bank features via PyKaldi."""
try:
from kaldi.feat.mel import MelBanksOptions
from kaldi.feat.fbank import FbankOptions, Fbank
from kaldi.feat.window import FrameExtractionOptions
from kaldi.matrix import Vector
mel_opts = MelBanksOptions()
mel_opts.num_bins = n_bins
frame_opts = FrameExtractionOptions()
frame_opts.samp_freq = sample_rate
opts = FbankOptions()
opts.mel_opts = mel_opts
opts.frame_opts = frame_opts
fbank = Fbank(opts=opts)
features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy()
return features
except ImportError:
return None
def _get_torchaudio_fbank(
waveform: np.ndarray, sample_rate, n_bins=80
) -> Optional[np.ndarray]:
"""Get mel-filter bank features via TorchAudio."""
try:
import torchaudio.compliance.kaldi as ta_kaldi
waveform = torch.from_numpy(waveform)
features = ta_kaldi.fbank(
waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
)
return features.numpy()
except ImportError:
return None
def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray:
"""Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
(faster CPP implementation) to TorchAudio (Python implementation). Note that
Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
waveform should not be normalized."""
waveform, sample_rate = get_waveform(path_or_fp, normalization=False)
features = _get_kaldi_fbank(waveform, sample_rate, n_bins)
if features is None:
features = _get_torchaudio_fbank(waveform, sample_rate, n_bins)
if features is None:
raise ImportError(
"Please install pyKaldi or torchaudio to enable "
"online filterbank feature extraction"
)
return features
def is_npy_data(data: bytes) -> bool:
return data[0] == 147 and data[1] == 78
def is_sf_audio_data(data: bytes) -> bool:
is_wav = (data[0] == 82 and data[1] == 73 and data[2] == 70)
is_flac = (data[0] == 102 and data[1] == 76 and data[2] == 97)
is_ogg = (data[0] == 79 and data[1] == 103 and data[2] == 103)
return is_wav or is_flac or is_ogg
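def _magic_bytes_example():
    """Illustrative sketch (not part of the original API): the helpers above only inspect
    the leading magic bytes, so short synthetic headers are enough to exercise them."""
    assert is_npy_data(b"\x93NUMPY\x01\x00")  # numpy .npy header
    assert is_sf_audio_data(b"RIFF\x00\x00")  # WAV (RIFF container)
    assert is_sf_audio_data(b"fLaC\x00\x00")  # FLAC stream
    assert is_sf_audio_data(b"OggS\x00\x00")  # Ogg container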
def read_from_stored_zip(zip_path: str, offset: int, file_size: int) -> bytes:
with open(zip_path, "rb") as f:
f.seek(offset)
data = f.read(file_size)
return data
def parse_path(path: str) -> Tuple[str, List[int]]:
"""Parse data path which is either a path to
1. a .npy/.wav/.flac/.ogg file
2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
Args:
path (str): the data path to parse
Returns:
file_path (str): the file path
slice_ptr (list of int): empty in case 1;
byte offset and length for the slice in case 2
"""
if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
_path, slice_ptr = path, []
else:
_path, *slice_ptr = path.split(":")
if not Path(_path).is_file():
raise FileNotFoundError(f"File not found: {_path}")
assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}"
slice_ptr = [int(i) for i in slice_ptr]
return _path, slice_ptr
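# The three _group_to_batches_by_* helpers below take a buffer of loaded examples plus a
# length-sorted list of (buffer_index, input_length[, label_length]) tuples and pack them
# into batches by utterance count, by padded frame count, or by padded frame x label size.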
def _group_to_batches_by_utters(buffer, sorted_idx_len_pair, batch_size):
batch_list = []
single_batch = []
for idx_len_pair in sorted_idx_len_pair:
single_batch.append(buffer[idx_len_pair[0]])
if len(single_batch) == batch_size:
batch_list.append(single_batch)
single_batch = []
if len(single_batch) > 0:
batch_list.append(single_batch)
return batch_list
def _group_to_batches_by_frames(buffer, sorted_idx_len_pair, batch_size):
batch_list = []
single_batch = []
frame_num_padded = 0
first_utt_len = sorted_idx_len_pair[0][1]
max_sentence = batch_size // first_utt_len // 8 * 8
for idx_len_pair in sorted_idx_len_pair:
if max_sentence == 0:
max_sentence = 8
frame_num_padded += first_utt_len
if frame_num_padded > batch_size or len(single_batch) == max_sentence:
if len(single_batch) > 0:
batch_list.append(single_batch)
single_batch = []
first_utt_len = idx_len_pair[1]
frame_num_padded = first_utt_len
max_sentence = batch_size // first_utt_len // 8 * 8
single_batch.append(buffer[idx_len_pair[0]])
if len(single_batch) > 0:
batch_list.append(single_batch)
return batch_list
def _group_to_batches_by_frame_x_label(buffer, sorted_idx_len_pair, batch_size):
batch_list = []
single_batch = []
frame_num_padded = 0
max_lab_len = sorted_idx_len_pair[0][2] + 1
max_utt_len = sorted_idx_len_pair[0][1]
for idx_len_pair in sorted_idx_len_pair:
if max_lab_len < idx_len_pair[2] + 1:
max_lab_len = idx_len_pair[2] + 1
        frame_num_padded = max_utt_len * max_lab_len * len(single_batch)
if frame_num_padded > batch_size:
if len(single_batch) > 0:
batch_list.append(single_batch)
single_batch = []
max_utt_len = idx_len_pair[1]
max_lab_len = idx_len_pair[2] + 1
single_batch.append(buffer[idx_len_pair[0]])
if len(single_batch) > 0:
batch_list.append(single_batch)
return batch_list
class DataParser():
def __init__(self):
super().__init__()
def _parse_data(self, data, data_type):
if data_type.lower() == 'audio':
parsed_data = self._parse_audio_data(data)
elif data_type.lower() == 'info':
parsed_data = self._parse_json_data(data)
elif data_type.lower() == "feature":
parsed_data = self._parse_feat_data(data)
else:
parsed_data = self._parse_string_data(data)
return parsed_data
def _parse_audio_data(self, data):
byte_stream = io.BytesIO(data)
with sf.SoundFile(byte_stream, 'r') as f:
samples = f.read()
return samples
def _parse_json_data(self, data):
str_data = str(data, 'utf-8')
json_data = json.loads(str_data)
return json_data
def _parse_string_data(self, data):
str_data = str(data, 'utf-8')
return str_data
def _parse_feat_data(self, data):
feat = np.frombuffer(data, dtype=np.float32)
feat = feat.reshape(-1, 80)
return feat
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pdb
import logging
import os
import sys
import json
import soundfile as sf
import numpy as np
import torch
import torch.nn.functional as F
from .. import FairseqDataset, data_utils
from fairseq.tokenizer import char_tokenizer
from fairseq.data.audio.audio_utils import _group_to_batches_by_frames, _group_to_batches_by_utters, _group_to_batches_by_frame_x_label, DataParser
ENDIAN = 'little'
logger = logging.getLogger(__name__)
class ChunkAudioDataset(torch.utils.data.IterableDataset, FairseqDataset):
def __init__(
self,
chunk_data_file,
chunk_data_path=None,
chunk_trans_path=None,
max_sample_size=None,
min_sample_size=None,
max_tokens=None,
pad=False,
normalize=False,
subset=None,
shuffle=True,
shard=True,
label=False,
dictionary=None,
feature="audio",
mean_file=None,
invstd_file=None,
batch_criterion="frame"
):
self._data_path = chunk_data_path
self._data_file = chunk_data_file
self._trans_path = chunk_trans_path
self.max_sample_size = (
max_sample_size if max_sample_size is not None else sys.maxsize
)
self.min_sample_size = min_sample_size
self.max_tokens = max_tokens
self.pad = pad
self.shuffle = shuffle
self.shard = shard
self.normalize = normalize
self.label = label
        self.dictionary = dictionary
self.feature = feature
if mean_file is not None:
self.mean = np.fromfile(mean_file, sep='\n')
else:
self.mean = None
if invstd_file is not None:
self.invstd = np.fromfile(invstd_file, sep='\n')
else:
self.invstd = None
with open(self._data_file) as f:
self._chunk_list = json.load(f)['fileInfo']
if self._data_path is None:
self._data_path = os.path.dirname(self._data_file)
if self._trans_path is None:
self._trans_path = os.path.dirname(self._data_file)
self._chunk_num = len(self._chunk_list)
self._example_num = 0
self._dist_size = 1
self._dist_rank = 0
self.end_of_epoch = False
for chunk in self._chunk_list:
self._example_num += int(chunk['count'])
logger.info(f"Open dataset {self._data_file}, total example count {self._example_num}")
self.subset = subset
self.parser = DataParser()
self._buffer_size = 3000
self._batch_criterion = batch_criterion
self._example_buffer = []
self._batch_buffer = []
self._first_iteration = True
self.iterable = None
def __len__(self):
return self._example_num
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
offset = self._dist_rank
skip = self._dist_size
else:
offset = self._dist_size * worker_info.id + self._dist_rank
skip = self._dist_size * worker_info.num_workers
#print(self._chunk_list[13])
if self.shard:
self._sharded_list = list(self._chunk_list[offset::skip])
value = len(self._chunk_list) % self._dist_size
            if value != 0 and self._dist_rank >= value:
if worker_info is None or worker_info.id == worker_info.num_workers - 1:
np.random.seed(self._dist_rank)
pad_chunk = np.random.choice(self._chunk_list)
self._sharded_list.append(pad_chunk)
else:
self._sharded_list = self._chunk_list
self.iterable = iter(self._chunk_deserializer())
#print("{}/{} worker init in gpu {}, sharded data {}/{}".format(worker_info.id, worker_info.num_workers, self._dist_rank, len(self._sharded_list), len(self._chunk_list)))
return self
def reset(self, world_size=1, world_rank=0):
#print("Reset Dataset")
self._example_buffer = []
self._batch_buffer = []
self._first_iteration = True
self._dist_size = world_size
self._dist_rank = world_rank
np.random.seed(self.epoch)
if self.shuffle:
np.random.shuffle(self._chunk_list)
def set_epoch(self, epoch):
self.epoch = epoch
def __next__(self):
return self._dynamicbatcher()
def _read_chunk(self, file_path, chunk_type, chunk_size):
example_list = []
with open(file_path, 'rb') as f:
target_type = f.read(len(chunk_type.encode())).decode()
if chunk_type.lower() != target_type.lower():
raise ValueError(
                    'Unexpected chunk type in {}: expected {}, but got {}'
.format(file_path, chunk_type, target_type))
version_number = int.from_bytes(f.read(4), byteorder=ENDIAN)
for i in range(chunk_size):
example_index = int.from_bytes(f.read(4), byteorder=ENDIAN)
if example_index != i:
raise ValueError(
                        'The example index is corrupted in {}, '
                        'expected {}, but got {}'.format(
                            file_path, i, example_index))
data_size = int.from_bytes(f.read(4), byteorder=ENDIAN)
data = f.read(data_size)
example_list.append(data)
return example_list
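    # Hedged sketch (not part of the original class): the inverse of _read_chunk,
    # spelled out to make the on-disk chunk layout explicit. A chunk file starts with
    # the type string and a 4-byte little-endian version, followed by one record per
    # example: a 4-byte running index, a 4-byte payload size, and the payload bytes.
    # The method name and the version value are illustrative only.
    @staticmethod
    def _example_write_chunk(file_path, chunk_type, payloads, version=1):
        with open(file_path, 'wb') as f:
            f.write(chunk_type.encode())
            f.write(version.to_bytes(4, byteorder=ENDIAN))
            for i, data in enumerate(payloads):
                f.write(i.to_bytes(4, byteorder=ENDIAN))
                f.write(len(data).to_bytes(4, byteorder=ENDIAN))
                f.write(data)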
def _chunk_deserializer(self):
try:
iterator = iter(self._sharded_list)
chunk = next(iterator)
while True:
chunk_type = ['info', self.feature]
if self.label:
chunk_type.append('transcription')
chunk_name = chunk['name']
chunk_size = int(chunk['count'])
example_dict = {}
for extension in chunk_type:
if extension == 'transcription':
file_path = os.path.join(self._trans_path, chunk_name+'.transcription')
else:
file_path = os.path.join(self._data_path, chunk_name+'.'+extension)
example_dict[extension] = self._read_chunk(file_path, extension, chunk_size)
example_lens = [len(example_dict[x]) for x in chunk_type]
if not all(x == chunk_size for x in example_lens):
error_msg = 'Chunk size is not consistent in chunk {}'.format(chunk_name)
raise ValueError(error_msg)
for i in range(chunk_size):
one_example = {}
for extension in chunk_type:
one_example[extension] = self.parser._parse_data(example_dict[extension][i], extension)
if self.subset is not None and self.subset not in one_example['info']['corpusname']:
break
if 'transcription' in one_example:
one_example['y'] = self.dictionary.encode_line(
one_example['transcription'].upper(), line_tokenizer=char_tokenizer,
add_if_not_exist=False,
append_eos=False
)
if self.feature not in one_example:
continue
yield one_example
chunk = next(iterator)
except StopIteration:
return
def _fill_buffer_by_length(self, buffer, length):
try:
i = 0
while i < length:
example = next(self.iterable)
x_len = example[self.feature].shape[0]
if self.pad and self.max_sample_size is not None and x_len > self.max_sample_size:
continue
if self.min_sample_size is not None and x_len < self.min_sample_size:
continue
buffer.append(example)
i += 1
except StopIteration:
pass
def _create_batch_list(self, example_list):
idx_len_pair = []
for idx in range(len(example_list)):
uttlen = len(example_list[idx][self.feature])
if 'y' in example_list[idx]:
target_len = len(example_list[idx]['y'])
else:
target_len = 1
idx_len_pair.append((idx, uttlen, target_len))
sorted_idx_len_pair = sorted(idx_len_pair, key=lambda var: var[1], reverse=self.pad)
if self._batch_criterion == "frame":
group_batches_fn = _group_to_batches_by_frames
elif self._batch_criterion == "utterance":
group_batches_fn = _group_to_batches_by_utters
elif self._batch_criterion == "frame_x_label":
group_batches_fn = _group_to_batches_by_frame_x_label
else:
raise ValueError("Only support for grouping batches by 'frame', 'utterance', 'frame_x_label'")
batch_list = group_batches_fn(
self._example_buffer, sorted_idx_len_pair, self.max_tokens)
if self.shuffle:
np.random.shuffle(batch_list)
return batch_list
def _dynamicbatcher(self):
if self._first_iteration:
self._first_iteration = False
self._fill_buffer_by_length(self._example_buffer, self._buffer_size)
if self.shuffle:
np.random.shuffle(self._example_buffer)
if not self._batch_buffer and not self._example_buffer:
raise StopIteration
if not self._batch_buffer:
self._batch_buffer = self._create_batch_list(self._example_buffer)
self._example_buffer = []
single_batch = self._batch_buffer.pop()
self._fill_buffer_by_length(self._example_buffer, len(single_batch))
if self.feature == "audio":
sources = [self.postprocess(torch.from_numpy(s[self.feature])).float() for s in single_batch]
else:
sources = [torch.from_numpy(self.mvn(s[self.feature])).float() for s in single_batch]
infos = [s['info'] for s in single_batch]
ids = torch.LongTensor(list(range(len(single_batch))))
if self.label:
target = [s['y'] for s in single_batch]
return {'info': infos, 'id': ids,'source': sources, "target": target}
return {'info': infos, 'id': ids,'source': sources}
def collater(self, samples):
samples = samples[0]
if len(samples["source"]) == 0:
return {}
sources = samples['source']
sizes = [len(s) for s in sources]
if self.pad:
target_size = min(max(sizes), self.max_sample_size)
else:
target_size = min(min(sizes), self.max_sample_size)
if self.feature == "audio":
collated_sources = sources[0].new_zeros(len(sources), target_size)
else:
collated_sources = sources[0].new_zeros(len(sources), target_size, 80)
padding_mask = (
torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
)
for i, (source, size) in enumerate(zip(sources, sizes)):
diff = size - target_size
if diff == 0:
collated_sources[i] = source
elif diff < 0:
assert self.pad
if self.feature == "audio":
collated_sources[i] = torch.cat(
[source, source.new_full((-diff,), 0.0)]
)
else:
collated_sources[i] = torch.cat(
[source, source.new_full((-diff, 80), 0.0)]
)
padding_mask[i, diff:] = True
else:
collated_sources[i] = self.crop_to_max_size(source, target_size)
input = {"source": collated_sources}
if self.pad:
input["padding_mask"] = padding_mask
collated = {"info": samples["info"], "id": samples["id"], "net_input": input}
if not self.label:
return collated
target = samples['target']
collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
        target = data_utils.collate_tokens(target, pad_idx=self.dictionary.pad(), left_pad=False)  # pad with the dictionary's pad index; self.pad is only a boolean flag
collated["ntokens"] = collated["target_lengths"].sum().item()
collated["target"] = target
return collated
def postprocess(self, feats):
if feats.dim() == 2:
feats = feats.mean(-1)
assert feats.dim() == 1, feats.dim()
if self.normalize:
with torch.no_grad():
feats = F.layer_norm(feats, feats.shape)
return feats
def mvn(self, feats):
feats = (feats - self.mean) * self.invstd
return feats
def crop_to_max_size(self, wav, target_size):
size = len(wav)
diff = size - target_size
if diff <= 0:
return wav
start = np.random.randint(0, diff+1)
end = size - diff + start
return wav[start:end]
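# A hedged usage sketch (not part of the original module). It assumes a chunk manifest
# such as "chunks.json" with the {"fileInfo": [{"name": ..., "count": ...}]} layout read
# by __init__, plus matching ".info" and ".audio" chunk files next to it; the manifest
# path and the token budget below are placeholders.
def _example_chunk_audio_loader(manifest="chunks.json"):
    dataset = ChunkAudioDataset(
        chunk_data_file=manifest,
        max_tokens=1600000,        # dynamic-batch budget used by the grouping helpers
        pad=True,
        shuffle=True,
        feature="audio",
        batch_criterion="frame",
    )
    dataset.set_epoch(0)                       # seeds the shuffling done in reset()
    dataset.reset(world_size=1, world_rank=0)  # single-process, single-GPU sharding
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, collate_fn=dataset.collater, num_workers=0
    )
    for batch in loader:
        # net_input["source"] is a padded (batch, time) float tensor
        print(batch["net_input"]["source"].shape)
        break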
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import sys
import io
import pdb
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import data_utils
from .. import FairseqDataset
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
from fairseq.data.audio.raw_audio_dataset import RawAudioDataset
from fairseq.data.audio.audio_utils import (
parse_path,
read_from_stored_zip,
is_sf_audio_data,
mulaw_encode,
preemphasis,
)
logger = logging.getLogger(__name__)
class FeatsAudioDataset(RawAudioDataset):
def __init__(
self,
manifest_path,
sample_rate,
input_feature="mfcc",
output_feature="mfcc",
max_sample_size=None,
min_sample_size=0,
shuffle=True,
pad=False,
normalize=False,
):
super().__init__(
sample_rate=sample_rate,
max_sample_size=max_sample_size,
min_sample_size=min_sample_size,
shuffle=shuffle,
pad=False,
normalize=normalize
)
self.chunk_names = []
self.chunk_indices = []
self.fnames = []
self.skipped = []
self.speakers = []
self.input_feature = input_feature
self.output_feature = output_feature
self.speaker_dict = {}
speaker_count = 0
skipped = 0
count = 0
sizes = []
self.skipped_indices = set()
with open(manifest_path, "r") as f:
self.root_dir = f.readline().strip()
for i, line in enumerate(f):
items = line.strip().split("\t")
#assert len(items) == 2, line
sz = int(items[1])
if self.input_feature != "wav":
sz = int(sz/self.sample_rate*100)
if min_sample_size is not None and sz < min_sample_size:
skipped += 1
self.skipped.append(i)
self.skipped_indices.add(i)
continue
fname = items[0].split(":")
if len(fname) > 1:
if len(self.chunk_names) == 0 or fname[0] != self.chunk_names[-1]:
self.chunk_names.append(fname[0])
self.chunk_indices.append(len(self.fnames))
self.fnames.append(items[0])
if len(items) > 2:
speaker = int(items[2])
else:
speaker = int(items[0].split("/")[-1].split("-")[0])
if speaker not in self.speaker_dict:
self.speaker_dict[speaker] = speaker_count
speaker_count += 1
self.speakers.append(self.speaker_dict[speaker])
sizes.append(sz)
logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
self.sizes = np.array(sizes, dtype=np.int64)
try:
import pyarrow
self.fnames = pyarrow.array(self.fnames)
except:
logger.debug(
"Could not create a pyarrow array. Please install pyarrow for better performance"
)
pass
def get_mfcc(self, wav, sample_rate=16000, normalize=True):
try:
import torchaudio
import torchaudio.compliance.kaldi as ta_kaldi
with torch.no_grad():
x = torch.from_numpy(wav).float()
x = x.view(1, -1)
mfccs = ta_kaldi.mfcc(
waveform=x,
sample_frequency=sample_rate,
use_energy=False,
) # (time, freq)
mfccs = mfccs.transpose(0, 1) # (freq, time)
deltas = torchaudio.functional.compute_deltas(mfccs)
ddeltas = torchaudio.functional.compute_deltas(deltas)
concat = torch.cat([mfccs, deltas, ddeltas], dim=0)
                concat = concat.transpose(0, 1).contiguous()  # (time, freq)
if normalize:
mean = concat.mean(dim=0)
std = concat.std(dim=0)
concat = (concat - mean) / std
return concat
except ImportError:
return None
def get_logmel(self, wav, sample_rate=16000, preemph=0.97, n_fft=2048, n_mels=80, hop_length=160,
win_length=400, fmin=50, top_db=80, bits=8, offset=0.0, duration=None):
wav = wav / np.abs(wav).max() * 0.999
try:
import librosa
mel = librosa.feature.melspectrogram(preemphasis(wav, preemph),
sr=sample_rate, n_fft=n_fft, n_mels=n_mels, hop_length=hop_length,
win_length=win_length, fmin=fmin, power=1)
logmel = librosa.amplitude_to_db(mel, top_db=top_db)
logmel = logmel / top_db + 1
logmel = torch.from_numpy(logmel).transpose(0, 1)
return logmel
except ImportError:
return None
def get_fbank(self, wav, n_bins=80, sample_rate=16000, normalize=True):
try:
import torchaudio.compliance.kaldi as ta_kaldi
x = torch.from_numpy(wav).float()
x = x.view(1, -1)
features = ta_kaldi.fbank(
x, num_mel_bins=n_bins, sample_frequency=sample_rate
)
if normalize:
mean = features.mean(dim=0)
std = features.std(dim=0)
features = (features - mean) / std
return features
except ImportError:
return None
def mulaw_encode(self, wav):
wav = wav / np.abs(wav).max() * 0.999
wav = mulaw_encode(wav, mu=2**8)
return wav
def __getitem__(self, index):
import soundfile as sf
path_or_fp = os.path.join(self.root_dir, str(self.fnames[index]))
_path, slice_ptr = parse_path(path_or_fp)
if len(slice_ptr) == 2:
byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
assert is_sf_audio_data(byte_data)
path_or_fp = io.BytesIO(byte_data)
wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")
if self.input_feature == "wav":
feats = torch.from_numpy(wav).float()
feats = self.postprocess(feats, curr_sample_rate)
elif self.input_feature == "fbank":
feats = self.get_fbank(wav, n_bins=80, sample_rate=curr_sample_rate)
elif self.input_feature == "mfcc":
feats = self.get_mfcc(wav, sample_rate=curr_sample_rate)
elif self.input_feature == "logmel":
feats = self.get_logmel(wav, sample_rate=curr_sample_rate)
elif self.input_feature == "mulaw":
feats = self.mulaw_encode(wav)
feats = torch.from_numpy(feats).long()
else:
raise ValueError("Unknown extra features {}".format(self.input_feature))
if self.output_feature == self.input_feature:
target = feats
elif self.output_feature == "wav":
target = torch.from_numpy(wav).float()
feats = self.postprocess(feats, curr_sample_rate)
elif self.output_feature == "fbank":
target = self.get_fbank(wav, n_bins=80, sample_rate=curr_sample_rate)
elif self.output_feature == "mfcc":
target = self.get_mfcc(wav, sample_rate=curr_sample_rate)
elif self.output_feature == "logmel":
target = self.get_logmel(wav, sample_rate=curr_sample_rate)
elif self.output_feature == "mulaw":
target = self.mulaw_encode(wav)
target = torch.from_numpy(target).long()
else:
raise ValueError("Unknown extra features {}".format(self.output_feature))
return {"id": index, "input": feats, "target": target, "speaker": self.speakers[index]}
def collater(self, samples):
samples = [s for s in samples if s["input"] is not None]
if len(samples) == 0:
return {}
inputs = [s["input"] for s in samples]
targets = [s["target"] for s in samples]
sizes = [len(s) for s in inputs]
speakers = [s['speaker'] for s in samples]
input_size = min(min(sizes), self.max_sample_size)
if input_size % 2 != 0:
input_size = input_size - 1
"""
if self.input_feature == "wav" or self.input_feature == "mulaw":
if self.output_feature in ["mfcc", "fbank"]:
target_rate = 1.0 / 160
if self.output_feature == "logmel":
target_rate = 1.0 / 160
start_offset = -1
end_offset = 1
elif self.input_feature == "mfcc" or self.input_feature == "fbank":
if self.output_feature not in ["mfcc", "fbank", "logmel"]:
target_rate = 160
elif self.input_feature == "logmel":
if self.output_feature not in ["mfcc", "fbank", "logmel"]:
target_rate = 160
"""
if self.input_feature == self.output_feature:
target_rate = 1
offset = 0
elif self.input_feature == "logmel" and self.output_feature =="mulaw":
target_rate = 160
offset = 1
else:
raise ValueError("Unsupport {} and {}".format(self.input_feature, self.output_feature))
if inputs[0].dim() == 2:
collated_inputs = inputs[0].new_zeros(len(inputs), input_size+offset*2, inputs[0].shape[-1])
else:
collated_inputs = inputs[0].new_zeros(len(inputs), input_size+offset*2)
if targets[0].dim() == 2:
collated_targets = targets[0].new_zeros(len(inputs), (input_size) * target_rate + offset, targets[0].shape[-1])
else:
collated_targets = targets[0].new_zeros(len(inputs), (input_size) * target_rate + offset)
for i, (input, size) in enumerate(zip(inputs, sizes)):
size = len(input)
start = np.random.randint(offset, size - input_size + 1)
collated_inputs[i] = input[start-offset: start + input_size + offset]
collated_targets[i] = targets[i][start * target_rate: (start+input_size) * target_rate + offset]
out = {"id": torch.LongTensor([s["id"] for s in samples]), "speakers":torch.LongTensor(speakers)}
out["input"] = collated_inputs
out["target"] = collated_targets
return out
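# A hedged usage sketch (not part of the original module). It assumes a wav2vec-style
# manifest: the first line is the audio root directory and every following line is
# "<relative/path>\t<num_samples>\t<speaker_id>" (without the third column the speaker
# id is parsed from the file name). The path is a placeholder, and torchaudio must be
# installed for the Kaldi-style fbank extraction.
def _example_feats_audio_dataset(manifest="train.tsv"):
    dataset = FeatsAudioDataset(
        manifest_path=manifest,
        sample_rate=16000,
        input_feature="fbank",    # 80-dim filterbank frames as model input
        output_feature="fbank",   # identical target (target_rate=1, offset=0 in collater)
        max_sample_size=3000,     # measured in frames once input_feature != "wav"
        min_sample_size=0,
    )
    sample = dataset[0]           # {"id", "input", "target", "speaker"}
    batch = dataset.collater([sample, dataset[1]])
    print(batch["input"].shape, batch["target"].shape, batch["speakers"])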
import importlib
import os
from abc import ABC, abstractmethod
from typing import Dict, Optional
class AudioFeatureTransform(ABC):
@classmethod
@abstractmethod
def from_config_dict(cls, config: Optional[Dict] = None):
pass
AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
def register_audio_feature_transform(name):
def register_audio_feature_transform_cls(cls):
if name in AUDIO_FEATURE_TRANSFORM_REGISTRY:
raise ValueError(f"Cannot register duplicate transform ({name})")
if not issubclass(cls, AudioFeatureTransform):
raise ValueError(
f"Transform ({name}: {cls.__name__}) must extend "
"AudioFeatureTransform"
)
if cls.__name__ in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES:
raise ValueError(
f"Cannot register audio feature transform with duplicate "
f"class name ({cls.__name__})"
)
AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls.__name__)
return cls
return register_audio_feature_transform_cls
def get_audio_feature_transform(name):
return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
transforms_dir = os.path.dirname(__file__)
for file in os.listdir(transforms_dir):
path = os.path.join(transforms_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
name = file[: file.find(".py")] if file.endswith(".py") else file
importlib.import_module("fairseq.data.audio.feature_transforms." + name)
class CompositeAudioFeatureTransform(AudioFeatureTransform):
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
_transforms = _config.get("transforms")
if _transforms is None:
return None
transforms = [
get_audio_feature_transform(_t).from_config_dict(_config.get(_t))
for _t in _transforms
]
return CompositeAudioFeatureTransform(transforms)
def __init__(self, transforms):
self.transforms = [t for t in transforms if t is not None]
def __call__(self, x):
for t in self.transforms:
x = t(x)
return x
def __repr__(self):
format_string = (
[self.__class__.__name__ + "("]
+ [f" {t.__repr__()}" for t in self.transforms]
+ [")"]
)
return "\n".join(format_string)
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("global_cmvn")
class GlobalCMVN(AudioFeatureTransform):
"""Global CMVN (cepstral mean and variance normalization). The global mean
and variance need to be pre-computed and stored in NumPy format (.npz)."""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return GlobalCMVN(_config.get("stats_npz_path"))
def __init__(self, stats_npz_path):
self.stats_npz_path = stats_npz_path
stats = np.load(stats_npz_path)
self.mean, self.std = stats["mean"], stats["std"]
def __repr__(self):
return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'
def __call__(self, x):
x = np.subtract(x, self.mean)
x = np.divide(x, self.std)
return x
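# A hedged sketch (not part of the original module): producing the .npz statistics file
# that GlobalCMVN expects and then applying the transform. The path and the random
# stand-in features are illustrative only; in practice the statistics come from the
# whole training corpus.
def _example_global_cmvn(stats_path="gcmvn_stats.npz"):
    feats = np.random.rand(1000, 80).astype(np.float32)  # stand-in corpus features
    np.savez(stats_path, mean=feats.mean(axis=0), std=feats.std(axis=0))
    transform = GlobalCMVN.from_config_dict({"stats_npz_path": stats_path})
    return transform(feats)  # roughly zero-mean, unit-variance per frequency bin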
import math
import numbers
from typing import Optional
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("specaugment")
class SpecAugmentTransform(AudioFeatureTransform):
"""SpecAugment (https://arxiv.org/abs/1904.08779)"""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return SpecAugmentTransform(
_config.get("time_warp_W", 0),
_config.get("freq_mask_N", 0),
_config.get("freq_mask_F", 0),
_config.get("time_mask_N", 0),
_config.get("time_mask_T", 0),
_config.get("time_mask_p", 0.0),
_config.get("mask_value", None),
)
def __init__(
self,
time_warp_w: int = 0,
freq_mask_n: int = 0,
freq_mask_f: int = 0,
time_mask_n: int = 0,
time_mask_t: int = 0,
time_mask_p: float = 0.0,
mask_value: Optional[float] = 0.0,
):
# Sanity checks
assert mask_value is None or isinstance(
mask_value, numbers.Number
), f"mask_value (type: {type(mask_value)}) must be None or a number"
if freq_mask_n > 0:
assert freq_mask_f > 0, (
f"freq_mask_F ({freq_mask_f}) "
f"must be larger than 0 when doing freq masking."
)
if time_mask_n > 0:
assert time_mask_t > 0, (
f"time_mask_T ({time_mask_t}) must be larger than 0 when "
f"doing time masking."
)
self.time_warp_w = time_warp_w
self.freq_mask_n = freq_mask_n
self.freq_mask_f = freq_mask_f
self.time_mask_n = time_mask_n
self.time_mask_t = time_mask_t
self.time_mask_p = time_mask_p
self.mask_value = mask_value
def __repr__(self):
return (
self.__class__.__name__
+ "("
+ ", ".join(
[
f"time_warp_w={self.time_warp_w}",
f"freq_mask_n={self.freq_mask_n}",
f"freq_mask_f={self.freq_mask_f}",
f"time_mask_n={self.time_mask_n}",
f"time_mask_t={self.time_mask_t}",
f"time_mask_p={self.time_mask_p}",
]
)
+ ")"
)
def __call__(self, spectrogram):
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
distorted = spectrogram.copy() # make a copy of input spectrogram.
num_frames = spectrogram.shape[0] # or 'tau' in the paper.
        num_freqs = spectrogram.shape[1]  # or 'nu' in the paper.
mask_value = self.mask_value
if mask_value is None: # if no value was specified, use local mean.
mask_value = spectrogram.mean()
if num_frames == 0:
return spectrogram
if num_freqs < self.freq_mask_f:
return spectrogram
if self.time_warp_w > 0:
if 2 * self.time_warp_w < num_frames:
import cv2
w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
upper, lower = distorted[:w0, :], distorted[w0:, :]
upper = cv2.resize(
upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
)
lower = cv2.resize(
lower,
dsize=(num_freqs, num_frames - w0 - w),
interpolation=cv2.INTER_LINEAR,
)
distorted = np.concatenate((upper, lower), axis=0)
for _i in range(self.freq_mask_n):
f = np.random.randint(0, self.freq_mask_f)
f0 = np.random.randint(0, num_freqs - f)
if f != 0:
distorted[:, f0 : f0 + f] = mask_value
max_time_mask_t = min(
self.time_mask_t, math.floor(num_frames * self.time_mask_p)
)
if max_time_mask_t < 1:
return distorted
for _i in range(self.time_mask_n):
t = np.random.randint(0, max_time_mask_t)
t0 = np.random.randint(0, num_frames - t)
if t != 0:
distorted[t0 : t0 + t, :] = mask_value
return distorted
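# A hedged usage sketch (not part of the original module): building the transform from a
# config dict and applying it to a random spectrogram. The parameter values are in the
# spirit of the LibriSpeech policies from the paper but are illustrative, not prescribed
# by this module.
def _example_specaugment():
    transform = SpecAugmentTransform.from_config_dict(
        {
            "freq_mask_N": 2,
            "freq_mask_F": 27,
            "time_mask_N": 2,
            "time_mask_T": 100,
            "time_mask_p": 1.0,
        }
    )
    spec = np.random.rand(500, 80).astype(np.float32)  # (frames, frequency bins)
    return transform(spec)  # masked copy; the input spectrogram is left untouched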