"test/verify/test_rnn_5args.cpp" did not exist on "ba33d25cd3c5acd92d9a8a0c28abb45b288af4f2"
Commit 12c90639 authored by “change”'s avatar “change”
Browse files

init

parent 417b607b
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import logging
import numpy as np
import torch
import os
import itertools
from fairseq.data import (
AppendTokenDataset,
ConcatDataset,
FairseqDataset,
PrependTokenDataset,
data_utils,
indexed_dataset,
)
logger = logging.getLogger(__name__)
def load_langtriple_dataset(
data_path,
split,
src,
src_dict,
ref,
ref_dict,
tgt,
tgt_dict,
combine,
dataset_impl,
upsample_primary,
left_pad_source,
left_pad_target,
max_source_positions,
max_target_positions,
prepend_bos=False,
load_alignments=False,
truncate_source=False,
append_source_id=False,
num_buckets=0,
shuffle=True,
pad_to_multiple=1,
prepend_bos_src=None,
lang_format="[{}]",
):
assert not truncate_source
def split_exists(split, src, ref, tgt, lang, data_path):
filename = os.path.join(data_path, "{}.{}-{}-{}.{}".format(split, src, ref, tgt, lang))
return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
src_datasets = []
ref_datasets = []
tgt_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else "")
# infer langcode
if split_exists(split_k, src, ref, tgt, src, data_path):
prefix = os.path.join(data_path, "{}.{}-{}-{}.".format(split_k, src, ref, tgt))
elif split_exists(split_k, tgt, ref, src, src, data_path):
prefix = os.path.join(data_path, "{}.{}-{}-{}.".format(split_k, tgt, ref, src))
else:
if k > 0:
break
else:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, data_path)
)
src_dataset = data_utils.load_indexed_dataset(
prefix + src, src_dict, dataset_impl
)
src_datasets.append(src_dataset)
ref_dataset = data_utils.load_indexed_dataset(
prefix + ref, ref_dict, dataset_impl
)
ref_datasets.append(ref_dataset)
tgt_dataset = data_utils.load_indexed_dataset(
prefix + tgt, tgt_dict, dataset_impl
)
if tgt_dataset is not None:
tgt_datasets.append(tgt_dataset)
logger.info(
"{} {} {}-{}-{} {} examples".format(
data_path, split_k, src, ref, tgt, len(src_datasets[-1])
)
)
if not combine:
break
assert len(src_datasets) == len(ref_datasets)
assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
if len(src_datasets) == 1:
src_dataset = src_datasets[0]
ref_dataset = ref_datasets[0]
tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
else:
sample_ratios = [1] * len(src_datasets)
sample_ratios[0] = upsample_primary
src_dataset = ConcatDataset(src_datasets, sample_ratios)
ref_dataset = ConcatDataset(ref_datasets, sample_ratios)
if len(tgt_datasets) > 0:
tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
else:
tgt_dataset = None
if prepend_bos:
assert hasattr(src_dict, "bos_index") and hasattr(ref_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
ref_dataset = PrependTokenDataset(ref_dataset, ref_dict.bos())
if tgt_dataset is not None:
tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
elif prepend_bos_src is not None:
logger.info(f"prepending src bos: {prepend_bos_src}")
src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
ref_dataset = PrependTokenDataset(ref_dataset, prepend_bos_src)
eos = None
if append_source_id:
src_dataset = AppendTokenDataset(
src_dataset, src_dict.index(lang_format.format(src))
)
ref_dataset = AppendTokenDataset(
ref_dataset, ref_dict.index(lang_format.format(ref))
)
if tgt_dataset is not None:
tgt_dataset = AppendTokenDataset(
tgt_dataset, tgt_dict.index(lang_format.format(tgt))
)
eos = tgt_dict.index(lang_format.format(tgt))
align_dataset = None
if load_alignments:
align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
align_dataset = data_utils.load_indexed_dataset(
align_path, None, dataset_impl
)
tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
return LanguageTripleDataset(
src_dataset,
src_dataset.sizes,
src_dict,
ref_dataset,
ref_dataset.sizes,
ref_dict,
tgt_dataset,
tgt_dataset_sizes,
tgt_dict,
left_pad_source=left_pad_source,
left_pad_target=left_pad_target,
align_dataset=align_dataset,
eos=eos,
num_buckets=num_buckets,
shuffle=shuffle,
pad_to_multiple=pad_to_multiple,
)
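# --- Illustrative note (not part of the original file) -----------------------
# A small sketch (hypothetical path and language codes) of the file naming that
# split_exists probes above: each field of the triple lives in its own indexed
# dataset, e.g. for src="en", ref="phn", tgt="km":
#   {data_path}/train.en-phn-km.en.(bin|idx)
#   {data_path}/train.en-phn-km.phn.(bin|idx)
#   {data_path}/train.en-phn-km.km.(bin|idx)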
def collate(
samples,
pad_idx,
eos_idx,
left_pad_source=True,
left_pad_target=False,
input_feeding=True,
pad_to_length=None,
pad_to_multiple=1,
):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
None,
left_pad,
move_eos_to_beginning,
pad_to_length=pad_to_length,
pad_to_multiple=pad_to_multiple,
)
def check_alignment(alignment, src_len, tgt_len):
if alignment is None or len(alignment) == 0:
return False
if (
alignment[:, 0].max().item() >= src_len - 1
or alignment[:, 1].max().item() >= tgt_len - 1
):
logger.warning("alignment size mismatch found, skipping alignment!")
return False
return True
def compute_alignment_weights(alignments):
"""
Given a tensor of shape [:, 2] containing the source-target indices
corresponding to the alignments, a weight vector containing the
inverse frequency of each target index is computed.
For example, if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
a tensor containing [1., 0.5, 0.5, 1.] is returned (since target
index 3 is repeated twice).
"""
align_tgt = alignments[:, 1]
_, align_tgt_i, align_tgt_c = torch.unique(
align_tgt, return_inverse=True, return_counts=True
)
align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
return 1.0 / align_weights.float()
id = torch.LongTensor([s["id"] for s in samples])
src_tokens = merge(
"source",
left_pad=left_pad_source,
pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
)
ref_tokens = merge(
"reference",
left_pad=left_pad_source,
pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
)
# sort by descending source length
src_lengths = torch.LongTensor(
[s["source"].ne(pad_idx).long().sum() for s in samples]
)
ref_lengths = torch.LongTensor(
[s["reference"].ne(pad_idx).long().sum() for s in samples]
)
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
ref_lengths = ref_lengths.index_select(0, sort_order)
ref_tokens = ref_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge(
"target",
left_pad=left_pad_target,
pad_to_length=pad_to_length["target"]
if pad_to_length is not None
else None,
)
target = target.index_select(0, sort_order)
tgt_lengths = torch.LongTensor(
[s["target"].ne(pad_idx).long().sum() for s in samples]
).index_select(0, sort_order)
ntokens = tgt_lengths.sum().item()
if samples[0].get("prev_output_tokens", None) is not None:
prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
elif input_feeding:
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
"target",
left_pad=left_pad_target,
move_eos_to_beginning=True,
pad_to_length=pad_to_length["target"]
if pad_to_length is not None
else None,
)
else:
ntokens = src_lengths.sum().item()
batch = {
"id": id,
"nsentences": len(samples),
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
},
"target": target,
"ref_tokens": ref_tokens,
"ref_lengths": ref_lengths,
}
if prev_output_tokens is not None:
batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
0, sort_order
)
if samples[0].get("alignment", None) is not None:
bsz, tgt_sz = batch["target"].shape
src_sz = batch["net_input"]["src_tokens"].shape[1]
offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
if left_pad_source:
offsets[:, 0] += src_sz - src_lengths
if left_pad_target:
offsets[:, 1] += tgt_sz - tgt_lengths
alignments = [
alignment + offset
for align_idx, offset, src_len, tgt_len in zip(
sort_order, offsets, src_lengths, tgt_lengths
)
for alignment in [samples[align_idx]["alignment"].view(-1, 2)]
if check_alignment(alignment, src_len, tgt_len)
]
if len(alignments) > 0:
alignments = torch.cat(alignments, dim=0)
align_weights = compute_alignment_weights(alignments)
batch["alignments"] = alignments
batch["align_weights"] = align_weights
if samples[0].get("constraints", None) is not None:
# Collate the packed constraints across the samples, padding to
# the length of the longest sample.
lens = [sample.get("constraints").size(0) for sample in samples]
max_len = max(lens)
constraints = torch.zeros((len(samples), max_len)).long()
for i, sample in enumerate(samples):
constraints[i, 0 : lens[i]] = sample.get("constraints")
batch["constraints"] = constraints.index_select(0, sort_order)
return batch
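# --- Illustrative usage (not part of the original file) ----------------------
# A minimal sketch of how `collate` pads, sorts, and builds teacher-forcing
# inputs. All token ids below, as well as pad_idx=1 and eos_idx=2, are made up
# for illustration; real values come from the fairseq Dictionary in use.
def _demo_collate():
    samples = [
        {"id": 0, "source": torch.LongTensor([4, 5, 2]),
         "reference": torch.LongTensor([6, 2]),
         "target": torch.LongTensor([7, 8, 9, 2])},
        {"id": 1, "source": torch.LongTensor([4, 5, 6, 7, 2]),
         "reference": torch.LongTensor([6, 7, 2]),
         "target": torch.LongTensor([7, 2])},
    ]
    batch = collate(samples, pad_idx=1, eos_idx=2)
    # Samples come back sorted by descending source length, so ids are [1, 0];
    # prev_output_tokens is the target with its final token moved to the front.
    print(batch["id"])
    print(batch["net_input"]["src_tokens"])
    print(batch["net_input"]["prev_output_tokens"])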
class LanguageTripleDataset(FairseqDataset):
"""
A triple of torch.utils.data.Datasets (source, reference, target).
Args:
src (torch.utils.data.Dataset): source dataset to wrap
src_sizes (List[int]): source sentence lengths
src_dict (~fairseq.data.Dictionary): source vocabulary
ref (torch.utils.data.Dataset): reference dataset to wrap
ref_sizes (List[int]): reference sentence lengths
ref_dict (~fairseq.data.Dictionary): reference vocabulary
tgt (torch.utils.data.Dataset, optional): target dataset to wrap
tgt_sizes (List[int], optional): target sentence lengths
tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
left_pad_source (bool, optional): pad source tensors on the left side
(default: True).
left_pad_target (bool, optional): pad target tensors on the left side
(default: False).
shuffle (bool, optional): shuffle dataset elements before batching
(default: True).
input_feeding (bool, optional): create a shifted version of the targets
to be passed into the model for teacher forcing (default: True).
remove_eos_from_source (bool, optional): if set, removes eos from end
of source if it's present (default: False).
append_eos_to_target (bool, optional): if set, appends eos to end of
target if it's absent (default: False).
align_dataset (torch.utils.data.Dataset, optional): dataset
containing alignments.
constraints (Tensor, optional): 2d tensor with a concatenated, zero-
delimited list of constraints for each sentence.
append_bos (bool, optional): if set, appends bos to the beginning of
source/target sentence.
num_buckets (int, optional): if set to a value greater than 0, then
batches will be bucketed into the given number of batch shapes.
src_lang_id (int, optional): source language ID, if set, the collated batch
will contain a field 'src_lang_id' in 'net_input' which indicates the
source language of the samples.
tgt_lang_id (int, optional): target language ID, if set, the collated batch
will contain a field 'tgt_lang_id' which indicates the target language
of the samples.
"""
def __init__(
self,
src,
src_sizes,
src_dict,
ref,
ref_sizes,
ref_dict,
tgt=None,
tgt_sizes=None,
tgt_dict=None,
left_pad_source=True,
left_pad_target=False,
shuffle=True,
input_feeding=True,
remove_eos_from_source=False,
append_eos_to_target=False,
align_dataset=None,
constraints=None,
append_bos=False,
eos=None,
num_buckets=0,
src_lang_id=None,
tgt_lang_id=None,
pad_to_multiple=1,
):
if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
if tgt is not None:
assert len(src) == len(
tgt
), "Source and target must contain the same number of examples"
assert len(src) == len(
ref
), "Source and reference must contain the same number of examples"
self.src = src
self.ref = ref
self.tgt = tgt
self.src_sizes = np.array(src_sizes)
self.ref_sizes = np.array(ref_sizes)
self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
self.sizes = (
np.vstack((self.src_sizes, self.tgt_sizes)).T
if self.tgt_sizes is not None
else self.src_sizes
)
self.src_dict = src_dict
self.ref_dict = ref_dict
self.tgt_dict = tgt_dict
self.left_pad_source = left_pad_source
self.left_pad_target = left_pad_target
self.shuffle = shuffle
self.input_feeding = input_feeding
self.remove_eos_from_source = remove_eos_from_source
self.append_eos_to_target = append_eos_to_target
self.align_dataset = align_dataset
if self.align_dataset is not None:
assert (
self.tgt_sizes is not None
), "Both source and target needed when alignments are provided"
self.constraints = constraints
self.append_bos = append_bos
self.eos = eos if eos is not None else src_dict.eos()
self.src_lang_id = src_lang_id
self.tgt_lang_id = tgt_lang_id
if num_buckets > 0:
from fairseq.data import BucketPadLengthDataset
self.src = BucketPadLengthDataset(
self.src,
sizes=self.src_sizes,
num_buckets=num_buckets,
pad_idx=self.src_dict.pad(),
left_pad=self.left_pad_source,
)
self.src_sizes = self.src.sizes
logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
self.ref = BucketPadLengthDataset(
self.ref,
sizes=self.ref_sizes,
num_buckets=num_buckets,
pad_idx=self.ref_dict.pad(),
left_pad=self.left_pad_source,
)
self.ref_sizes = self.ref.sizes
logger.info("bucketing reference lengths: {}".format(list(self.src.buckets)))
if self.tgt is not None:
self.tgt = BucketPadLengthDataset(
self.tgt,
sizes=self.tgt_sizes,
num_buckets=num_buckets,
pad_idx=self.tgt_dict.pad(),
left_pad=self.left_pad_target,
)
self.tgt_sizes = self.tgt.sizes
logger.info(
"bucketing target lengths: {}".format(list(self.tgt.buckets))
)
# determine bucket sizes using self.num_tokens, which will return
# the padded lengths (thanks to BucketPadLengthDataset)
num_tokens = np.vectorize(self.num_tokens, otypes=[np.compat.long])
self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
self.buckets = [
(None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
]
else:
self.buckets = None
self.pad_to_multiple = pad_to_multiple
def get_batch_shapes(self):
return self.buckets
def __getitem__(self, index):
tgt_item = self.tgt[index] if self.tgt is not None else None
src_item = self.src[index]
ref_item = self.ref[index]
# Append EOS to end of tgt sentence if it does not have an EOS and remove
# EOS from end of src sentence if it exists. This is useful when we
# use existing datasets for opposite directions, i.e., when we want to
# use tgt_dataset as src_dataset and vice versa
if self.append_eos_to_target:
eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
if self.tgt and self.tgt[index][-1] != eos:
tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])
if self.append_bos:
bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
if self.tgt and self.tgt[index][0] != bos:
tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])
bos = self.src_dict.bos()
if self.src[index][0] != bos:
src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])
if self.ref[index][0] != bos:
ref_item = torch.cat([torch.LongTensor([bos]), self.ref[index]])
if self.remove_eos_from_source:
eos = self.src_dict.eos()
if self.src[index][-1] == eos:
src_item = self.src[index][:-1]
if self.ref[index][-1] == eos:
ref_item = self.ref[index][:-1]
example = {
"id": index,
"source": src_item,
"reference": ref_item,
"target": tgt_item,
}
if self.align_dataset is not None:
example["alignment"] = self.align_dataset[index]
if self.constraints is not None:
example["constraints"] = self.constraints[index]
return example
def __len__(self):
return len(self.src)
def collater(self, samples, pad_to_length=None):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
pad_to_length (dict, optional): a dictionary of
{'source': source_pad_to_length, 'target': target_pad_to_length}
to indicate the max length to pad to in source and target respectively.
Returns:
dict: a mini-batch with the following keys:
- `id` (LongTensor): example IDs in the original input order
- `ntokens` (int): total number of tokens in the batch
- `net_input` (dict): the input to the Model, containing keys:
- `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
the source sentence of shape `(bsz, src_len)`. Padding will
appear on the left if *left_pad_source* is ``True``.
- `src_lengths` (LongTensor): 1D Tensor of the unpadded
lengths of each source sentence of shape `(bsz)`
- `prev_output_tokens` (LongTensor): a padded 2D Tensor of
tokens in the target sentence, shifted right by one
position for teacher forcing, of shape `(bsz, tgt_len)`.
This key will not be present if *input_feeding* is
``False``. Padding will appear on the left if
*left_pad_target* is ``True``.
- `src_lang_id` (LongTensor): a long Tensor which contains source
language IDs of each sample in the batch
- `target` (LongTensor): a padded 2D Tensor of tokens in the
target sentence of shape `(bsz, tgt_len)`. Padding will appear
on the left if *left_pad_target* is ``True``.
- `tgt_lang_id` (LongTensor): a long Tensor which contains target language
IDs of each sample in the batch
"""
res = collate(
samples,
pad_idx=self.src_dict.pad(),
eos_idx=self.eos,
left_pad_source=self.left_pad_source,
left_pad_target=self.left_pad_target,
input_feeding=self.input_feeding,
pad_to_length=pad_to_length,
pad_to_multiple=self.pad_to_multiple,
)
if self.src_lang_id is not None or self.tgt_lang_id is not None:
src_tokens = res["net_input"]["src_tokens"]
bsz = src_tokens.size(0)
if self.src_lang_id is not None:
res["net_input"]["src_lang_id"] = (
torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
)
if self.tgt_lang_id is not None:
res["tgt_lang_id"] = (
torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
)
return res
def num_tokens(self, index):
"""Return the number of tokens in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
return max(
self.src_sizes[index],
self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
)
def num_tokens_vec(self, indices):
"""Return the number of tokens for a set of positions defined by indices.
This value is used to enforce ``--max-tokens`` during batching."""
sizes = self.src_sizes[indices]
if self.tgt_sizes is not None:
sizes = np.maximum(sizes, self.tgt_sizes[indices])
return sizes
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return (
self.src_sizes[index],
self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
indices = np.random.permutation(len(self)).astype(np.int64)
else:
indices = np.arange(len(self), dtype=np.int64)
if self.buckets is None:
# sort by target length, then source length
if self.tgt_sizes is not None:
indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
else:
# sort by bucketed_num_tokens, which is:
# max(padded_src_len, padded_tgt_len)
return indices[
np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
]
@property
def supports_prefetch(self):
return getattr(self.src, "supports_prefetch", False) and (
getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
)
def prefetch(self, indices):
self.src.prefetch(indices)
if self.tgt is not None:
self.tgt.prefetch(indices)
if self.align_dataset is not None:
self.align_dataset.prefetch(indices)
def filter_indices_by_size(self, indices, max_sizes):
"""Filter a list of sample indices. Remove those that are longer
than specified in max_sizes.
Args:
indices (np.array): original array of sample indices
max_sizes (int or list[int] or tuple[int]): max sample size,
can be defined separately for src and tgt (then list or tuple)
Returns:
np.array: filtered sample array
list: list of removed indices
"""
return data_utils.filter_paired_dataset_indices_by_size(
self.src_sizes,
self.tgt_sizes,
indices,
max_sizes,
)
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/fairseq/tasks/translation.py
1. Add a custom lang_format argument in function load_langpair_dataset
2. If truncate_source (disabled by default), use RandomCropDataset instead of TruncateDataset
"""
import itertools
import logging
import os
from fairseq.data import (
AppendTokenDataset,
LanguagePairDataset,
PrependTokenDataset,
StripTokenDataset,
TruncateDataset,
RandomCropDataset,
data_utils,
indexed_dataset,
)
from speechlm.data.concat_dataset import ConcatDataset
EVAL_BLEU_ORDER = 4
logger = logging.getLogger(__name__)
def load_langpair_dataset(
data_path,
split,
src,
src_dict,
tgt,
tgt_dict,
combine,
dataset_impl,
upsample_primary,
left_pad_source,
left_pad_target,
max_source_positions,
max_target_positions,
prepend_bos=False,
load_alignments=False,
truncate_source=False,
append_source_id=False,
num_buckets=0,
shuffle=True,
pad_to_multiple=1,
prepend_bos_src=None,
lang_format="[{}]",
input_feeding=True,
):
def split_exists(split, src, tgt, lang, data_path):
filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
src_datasets = []
tgt_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else "")
# infer langcode
if split_exists(split_k, src, tgt, src, data_path):
prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src, data_path):
prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
else:
if k > 0:
break
else:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, data_path)
)
src_dataset = data_utils.load_indexed_dataset(
prefix + src, src_dict, dataset_impl
)
if truncate_source:
src_dataset = AppendTokenDataset(
RandomCropDataset(
StripTokenDataset(src_dataset, src_dict.eos()),
max_source_positions - 1,
),
src_dict.eos(),
)
src_datasets.append(src_dataset)
tgt_dataset = data_utils.load_indexed_dataset(
prefix + tgt, tgt_dict, dataset_impl
)
if tgt_dataset is not None:
tgt_datasets.append(tgt_dataset)
logger.info(
"{} {} {}-{} {} examples".format(
data_path, split_k, src, tgt, len(src_datasets[-1])
)
)
if not combine:
break
assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
if len(src_datasets) == 1:
src_dataset = src_datasets[0]
tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
else:
sample_ratios = [1] * len(src_datasets)
sample_ratios[0] = upsample_primary
src_dataset = ConcatDataset(src_datasets, sample_ratios)
if len(tgt_datasets) > 0:
tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
else:
tgt_dataset = None
if prepend_bos:
assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
if tgt_dataset is not None:
tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
elif prepend_bos_src is not None:
logger.info(f"prepending src bos: {prepend_bos_src}")
src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
eos = None
if append_source_id:
src_dataset = AppendTokenDataset(
src_dataset, src_dict.index(lang_format.format(src))
)
if tgt_dataset is not None:
tgt_dataset = AppendTokenDataset(
tgt_dataset, tgt_dict.index(lang_format.format(tgt))
)
eos = tgt_dict.index(lang_format.format(tgt))
align_dataset = None
if load_alignments:
align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
align_dataset = data_utils.load_indexed_dataset(
align_path, None, dataset_impl
)
tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
return LanguagePairDataset(
src_dataset,
src_dataset.sizes,
src_dict,
tgt_dataset,
tgt_dataset_sizes,
tgt_dict,
left_pad_source=left_pad_source,
left_pad_target=left_pad_target,
align_dataset=align_dataset,
eos=eos,
num_buckets=num_buckets,
shuffle=shuffle,
pad_to_multiple=pad_to_multiple,
input_feeding=input_feeding,
)
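# --- Illustrative note (not part of the original file) -----------------------
# A small sketch (hypothetical paths) of how load_langpair_dataset infers the
# language-pair direction: split_exists first probes the src-tgt prefix and then
# falls back to tgt-src, e.g. "data-bin/train.en-de.en.bin/.idx".
def _expected_pair_prefixes(data_path, split, src, tgt):
    return [
        os.path.join(data_path, "{}.{}-{}.".format(split, src, tgt)),
        os.path.join(data_path, "{}.{}-{}.".format(split, tgt, src)),
    ]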
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import logging
from os import replace
import time
from collections import OrderedDict
from typing import Any, Dict, List, Optional
import numpy as np
from fairseq.data import data_utils
from fairseq.data import FairseqDataset
logger = logging.getLogger(__name__)
class MultiCorpusDataset(FairseqDataset):
"""
see fairseq/fairseq/data/multi_corpus_dataset.__doc__
Args:
datasets: an OrderedDict of FairseqDataset instances.
max_positions: a Dict mapping each dataset key to its max positions, used to
filter over-long samples when check_length is true
distribution: a List containing the probability of getting an utterance from
the corresponding dataset
max_tokens_ratio: a List of per-dataset ratios applied to max_tokens when batching
seed: random seed for sampling the datasets
sort_indices: if true, will sort the ordered indices by size
check_length: if true, filter the sampled indices by max_positions (slow on large datasets)
"""
def __init__(
self,
datasets: Dict[str, FairseqDataset],
max_positions: Dict,
distribution: List[float],
max_tokens_ratio: List[float],
seed: int = 1234,
sort_indices: bool = False,
check_length: bool = False,
):
super().__init__()
assert isinstance(datasets, OrderedDict)
assert len(datasets) == len(distribution)
# assert sum(distribution) == 1
self.datasets = datasets
self.distribution = distribution
self.max_tokens_ratio = max_tokens_ratio
self.seed = seed
self.sort_indices = sort_indices
self.max_positions = max_positions
self.check_length = check_length
# Avoid repeated conversions to list later
self.dataset_list = list(datasets.values())
self.total_num_instances = 0
# first_dataset = self.dataset_list[0]
self.num_instances_per_dataset = []
self.dataset_offsets = []
for i, dataset in enumerate(self.dataset_list):
assert isinstance(dataset, FairseqDataset)
# assert type(dataset) is type(first_dataset)
self.num_instances_per_dataset.append(
0 if self.distribution[i] == 0 else len(dataset)
)
self.dataset_offsets.append(self.total_num_instances)
self.total_num_instances += self.num_instances_per_dataset[i]
def ordered_indices(self):
start = time.time()
with data_utils.numpy_seed(self.seed, self.epoch):
logger.info(f"sampling new dataset with seed {self.seed} epoch {self.epoch}")
sampled_indices = {}
# For each dataset i, sample self.distribution[i] * self.total_num_instances
for i, key in enumerate(self.datasets):
tp = time.time()
if self.distribution[i] == 0:
# skip dataset if sampling probability is 0
continue
if i < len(self.datasets) - 1:
num_instances = int(self.distribution[i] * self.total_num_instances)
high = self.dataset_offsets[i + 1]
else:
num_instances = int(self.distribution[i] * self.total_num_instances)
high = self.total_num_instances
logger.info(f"sampling {num_instances} from {key} dataset")
# First, add k copies of the dataset where k = num_instances // len(dataset).
# This ensures an equal distribution of the data points as much as possible.
# For the remaining entries randomly sample them
dataset_size = len(self.datasets[key])
num_copies = num_instances // dataset_size
dataset_indices = np.random.permutation(high - self.dataset_offsets[i])[: num_instances - num_copies * dataset_size]
if num_copies > 0:
dataset_indices = np.concatenate(
(
np.repeat(
np.arange(high - self.dataset_offsets[i]), num_copies
),
dataset_indices,
)
)
# filter by size; we usually skip this by setting check_length=False,
# as it is very time-consuming on a large dataset
if self.max_positions[key] is not None and self.check_length:
dataset_indices, ignored = self.datasets[key].filter_indices_by_size(
dataset_indices,
self.max_positions[key],
)
if len(ignored) > 0:
logger.warning(
(
"{:,} samples have invalid sizes and will be skipped, "
"max_positions={}, first few sample ids={}"
).format(len(ignored), self.max_positions[key], ignored[:10])
)
if self.sort_indices:
logger.info(" - sampled indices took {}s".format(time.time() - tp))
tp = time.time()
dataset_indices = np.sort(dataset_indices)
ordered_indices = self.datasets[key].ordered_indices()
if isinstance(ordered_indices[0], np.ndarray): # chunked audio data
dataset_indices = [order_idx + self.dataset_offsets[i] for order_idx in ordered_indices]
assert self.dataset_offsets[i] == 0
# TODO for chunked audio data, now assume len(dataset_indices) == len(dataset). Don't filter any data.
else:
dataset_indices = ordered_indices[dataset_indices] + self.dataset_offsets[i]
logger.info(" - ordered_indices took {}s".format(time.time() - tp))
else:
np.random.shuffle(dataset_indices)
sampled_indices[key] = dataset_indices
logger.info(
"multi_corpus_dataset ordered_indices took {}s".format(
time.time() - start
)
)
return sampled_indices
def _map_index(self, index: int):
"""
If dataset A has length N and dataset B has length M
then index 1 maps to index 1 of dataset A, and index N + 1
maps to index 1 of B.
"""
counter = 0
for num_instances, key in zip(self.num_instances_per_dataset, self.datasets):
if index < counter + num_instances:
return index - counter, key
counter += num_instances
raise ValueError(
"Invalid index: {}, max: {}".format(index, self.total_num_instances)
)
def __len__(self):
"""
Length of this dataset is the sum of the lengths of the individual datasets
"""
return self.total_num_instances
def __getitem__(self, index):
new_index, key = self._map_index(index)
try:
item = self.datasets[key][new_index]
item["full_id"] = index
return item
except Exception as e:
e.args = (f"Error from {key} dataset", *e.args)
raise
def collater(self, samples):
"""
If we are doing batch sampling, then pick the right collater to use.
Otherwise we assume all collaters are the same.
"""
if len(samples) == 0:
return None
samples_dict = {key: [] for key in self.datasets}
for s in samples:
_, key = self._map_index(s["full_id"])
samples_dict[key].append(s)
batch = {}
for key in samples_dict:
if len(samples_dict[key]) == 0:
continue
batch[key] = self.datasets[key].collater(samples_dict[key])
return batch
def num_tokens(self, index: int):
index, key = self._map_index(index)
return self.datasets[key].num_tokens(index)
def size(self, index: int):
index, key = self._map_index(index)
return self.datasets[key].size(index)
@property
def can_reuse_epoch_itr_across_epochs(self):
return False
def set_epoch(self, epoch, **unused):
super().set_epoch(epoch)
logger.info(f"setting epoch of multi_corpus_dataset to {epoch}")
for ds in self.dataset_list:
if hasattr(ds, "set_epoch"):
ds.set_epoch(epoch)
self.epoch = epoch
@property
def supports_prefetch(self):
return False
@property
def supports_fetch_outside_dataloader(self):
return all(
self.datasets[key].supports_fetch_outside_dataloader
for key in self.datasets
)
def batch_by_size(
self,
indices,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
):
dataset_indices = indices
batches_dict = {}
for n, key in enumerate(dataset_indices):
max_tokens_ratio = self.max_tokens_ratio[n]
if isinstance(dataset_indices[key][0], np.ndarray): # chunked audio data
cur_batches = self.datasets[key].batch_by_size(
dataset_indices[key],
round(max_tokens * max_tokens_ratio),
max_sentences,
required_batch_size_multiple,
)
logger.info(f"Created {sum([len(b) for b in cur_batches])} [{len(cur_batches)}] batches for dataset {key}")
else:
cur_batches = super().batch_by_size(
np.array(dataset_indices[key], dtype=np.int64),
round(max_tokens * max_tokens_ratio),
max_sentences,
required_batch_size_multiple,
)
logger.info(f"Created {len(cur_batches)} batches for dataset {key}")
batches_dict[key] = cur_batches
return batches_dict
def get_batch_sampler(
self,
indices,
num_shards,
seed,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
split_modality_batch=False,
):
def batch_sampler(dataset, epoch):
start = time.time()
batches_dict = dataset.batch_by_size(
indices,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
logger.info(f"multi_corpus_dataset, batch_by_size took {time.time() - start}s")
start = time.time()
new_batches = []
### shuffle within inner buckets, then split into speech/text batches
shuffled_batches_list = []
speech_batches = []
### we handle speech_batches separately because we need to concatenate different
# speech datasets (e.g. ltr or km) instead of loading them in parallel.
for name, batches in batches_dict.items():
if name.startswith("speech"):
if isinstance(batches[0], list): # chunked audio data
batches = self.datasets[name].shuffle_batches(list(batches), seed + epoch)
shuffled_batches_list.append(batches)
else:
batches = inner_bucket_shuffle(batches, seed+epoch, num_shards*10)
batches = batches[: (len(batches) // num_shards) * num_shards]
if len(batches) == 0:
logger.warning(f"Sample 0 batch for {name}, you should ensure that no {name} data provided.")
else:
speech_batches += batches
else:
batches = inner_bucket_shuffle(batches, seed+epoch, num_shards*10)
batches = batches[: (len(batches) // num_shards) * num_shards]
if len(batches) == 0:
logger.warning(f"Sample 0 batch for {name}, you should ensure that no {name} data provided.")
else:
batches = shuffle_buckets(batches, seed=seed+epoch, inner_shuf=False)
shuffled_batches_list.append(batches)
if len(speech_batches) > 0:
speech_batches = shuffle_buckets(speech_batches, seed=seed+epoch, inner_shuf=False)
shuffled_batches_list.append(speech_batches)
### create the final new_batches
num_batch = min(len(batches) for batches in shuffled_batches_list)
if split_modality_batch:
for i in range(0, num_batch, num_shards):
for batches in shuffled_batches_list:
new_batches += batches[i: i + num_shards]
else:
for i in range(num_batch):
new_batches.append(np.concatenate([batches[i] for batches in shuffled_batches_list]))
logger.info(f"multi_corpus_dataset sample {len(new_batches)} batches, took {time.time() - start}s")
return new_batches
def inner_bucket_shuffle(batches, seed, bucket_size=10, thr=0):
"""we assert batches is sorted form long to short.
shuffle samples in a buctet(e.g. 10 batches).
batches: a list of numpy array"""
num_batch = len(batches)
new_batches = []
num_buckets = len(batches) // bucket_size
i = 0
while i < num_batch:
if (i < bucket_size * thr or
i >= bucket_size * (num_buckets - thr)
):
new_batches.append(batches[i])
i += 1
else:
group = np.concatenate(batches[i: i+bucket_size])
with data_utils.numpy_seed(seed):
np.random.shuffle(group)
new_batches += np.array_split(group, bucket_size)
i += bucket_size
assert all([len(batch) > 0 for batch in new_batches])
return new_batches
def shuffle_buckets(batches, seed, inner_shuf=True):
if inner_shuf:
batches = inner_bucket_shuffle(batches, seed, num_shards*10)
batches = [batches[i: i + num_shards] for i in range(0, len(batches)-num_shards+1, num_shards)]
assert len(batches[-1]) == num_shards
new_batches = []
with data_utils.numpy_seed(seed):
np.random.shuffle(batches)
for group in batches:
new_batches += group
return new_batches
return batch_sampler
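# --- Illustrative usage (not part of the original file) ----------------------
# A minimal sketch of how global indices map onto member datasets: with two toy
# datasets of length 3 and 2, indices 0-2 address the first and 3-4 the second.
# The _ToyDataset class and all constructor arguments below are hypothetical.
class _ToyDataset(FairseqDataset):
    def __init__(self, n):
        self.n = n
    def __len__(self):
        return self.n
    def __getitem__(self, index):
        return {"id": index}

def _demo_map_index():
    mcd = MultiCorpusDataset(
        datasets=OrderedDict([("speech", _ToyDataset(3)), ("text", _ToyDataset(2))]),
        max_positions={"speech": None, "text": None},
        distribution=[0.5, 0.5],
        max_tokens_ratio=[1.0, 1.0],
    )
    assert len(mcd) == 5
    assert mcd._map_index(0) == (0, "speech")
    assert mcd._map_index(3) == (0, "text")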
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
import numpy as np
import torch
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset,
SpeechToTextDatasetCreator,
S2TDataConfig,
_collate_frames,
get_features_or_waveform,
)
from fairseq.data import Dictionary, data_utils as fairseq_data_utils
@dataclass
class TextToUnitDatasetItem(object):
index: int
source: torch.Tensor
target: Optional[torch.Tensor] = None
speaker_id: Optional[int] = None
speaker_emb: Optional[torch.Tensor] = None
duration: Optional[torch.Tensor] = None
pitch: Optional[torch.Tensor] = None
energy: Optional[torch.Tensor] = None
class Text2UnitDataset(SpeechToTextDataset):
def __init__(
self,
split: str,
is_train_split: bool,
cfg: S2TDataConfig,
unit_labels: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None,
durations: Optional[List[List[int]]] = None,
pitches: Optional[List[str]] = None,
energies: Optional[List[str]] = None,
):
super(Text2UnitDataset, self).__init__(
split,
is_train_split,
cfg,
unit_labels,
n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
)
self.durations = durations
self.pitches = pitches
self.energies = energies
self.unit_labels = unit_labels
self.feature_root = Path(cfg.audio_root)
self.spk_emb_type = cfg.config.get("speaker_embedding_type", None)
self.random_spk = cfg.config.get("random_speaker", False)
if self.spk_emb_type is not None:
self.spk_emb_choices = [i for i in (self.feature_root / self.spk_emb_type).glob("*.npy")]
self.spk_emb_num = len(self.spk_emb_choices)
def __getitem__(self, index: int) -> TextToUnitDatasetItem:
# s2t_item = super().__getitem__(index)
source = torch.LongTensor(self.unit_labels[index])
target = None
if self.tgt_texts is not None:
tokenized = self.get_tokenized_tgt_text(index)
target = self.tgt_dict.encode_line(
tokenized, add_if_not_exist=False, append_eos=self.append_eos
).long()
if self.cfg.prepend_tgt_lang_tag:
lang_tag_idx = self.get_lang_tag_idx(
self.tgt_langs[index], self.tgt_dict
)
target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
speaker_id = None
if self.speaker_to_id is not None:
speaker_id = self.speaker_to_id[self.speakers[index]]
speaker_emb = None
if self.spk_emb_type is not None:
if self.random_spk:
spk_emb_path = self.spk_emb_choices[np.random.choice(self.spk_emb_num)]
else:
spk_emb_path = self.feature_root / self.spk_emb_type / f"{self.ids[index]}.npy"
speaker_emb = get_features_or_waveform(spk_emb_path)
speaker_emb = torch.from_numpy(speaker_emb).float()
duration, pitch, energy = None, None, None
if self.durations is not None:
duration = torch.tensor(
self.durations[index] + [0], dtype=torch.long # pad 0 for EOS
)
if self.pitches is not None:
pitch = get_features_or_waveform(self.pitches[index])
pitch = torch.from_numpy(
np.concatenate((pitch, [0])) # pad 0 for EOS
).float()
if self.energies is not None:
energy = get_features_or_waveform(self.energies[index])
energy = torch.from_numpy(
np.concatenate((energy, [0])) # pad 0 for EOS
).float()
return TextToUnitDatasetItem(
index=index,
source=source,
target=target,
speaker_id=speaker_id,
speaker_emb=speaker_emb,
duration=duration,
pitch=pitch,
energy=energy,
)
def collater(self, samples: List[TextToUnitDatasetItem]) -> Dict[str, Any]:
if len(samples) == 0:
return {}
src_lengths, order = torch.tensor(
[s.target.shape[0] for s in samples], dtype=torch.long
).sort(descending=True)
id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
0, order
)
target = fairseq_data_utils.collate_tokens(
[s.source for s in samples],
self.tgt_dict.pad(),
).index_select(0, order)
target_lengths = torch.tensor(
[s.source.shape[0] for s in samples], dtype=torch.long
).index_select(0, order)
src_tokens = fairseq_data_utils.collate_tokens(
[s.target for s in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
).index_select(0, order)
speaker = None
if self.speaker_to_id is not None:
speaker = (
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
.index_select(0, order)
.view(-1, 1)
)
if self.spk_emb_type is not None:
speaker = torch.stack([s.speaker_emb for s in samples], dim=0).index_select(0, order)
bsz, _ = target.size()
prev_output_tokens = torch.cat(
(target.new_zeros((bsz, self.tgt_dict.bos())), target[:, :-1]), dim=1
)
durations, pitches, energies = None, None, None
if self.durations is not None:
durations = fairseq_data_utils.collate_tokens(
[s.duration for s in samples], 0
).index_select(0, order)
assert src_tokens.shape[1] == durations.shape[1]
if self.pitches is not None:
pitches = _collate_frames([s.pitch for s in samples], True)
pitches = pitches.index_select(0, order)
assert src_tokens.shape[1] == pitches.shape[1]
if self.energies is not None:
energies = _collate_frames([s.energy for s in samples], True)
energies = energies.index_select(0, order)
assert src_tokens.shape[1] == energies.shape[1]
src_texts = [self.tgt_dict.string(samples[i].target) for i in order]
return {
"id": id_,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"prev_output_tokens": prev_output_tokens,
},
"speaker": speaker,
"target": traget,
"durations": durations,
"pitches": pitches,
"energies": energies,
"target_lengths": target_lengths,
"ntokens": sum(target_lengths).item(),
"nsentences": len(samples),
"src_texts": src_texts,
}
class Text2UnitDatasetCreator(SpeechToTextDatasetCreator):
KEY_DURATION = "duration"
KEY_PITCH = "pitch"
KEY_ENERGY = "energy"
KEY_UNIT = "unit"
@classmethod
def _from_list(
cls,
split_name: str,
is_train_split,
samples: List[Dict],
cfg: S2TDataConfig,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
) -> Text2UnitDataset:
audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
# audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
unit_labels = [s[cls.KEY_UNIT] for s in samples]
unit_labels = [
None if dd is None else [int(d) for d in dd.split(" ")] for dd in unit_labels
]
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
durations = [s.get(cls.KEY_DURATION, None) for s in samples]
durations = [
None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
]
durations = None if any(dd is None for dd in durations) else durations
pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
pitches = [
None if pp is None else (audio_root / pp).as_posix() for pp in pitches
]
pitches = None if any(pp is None for pp in pitches) else pitches
energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
energies = [
None if ee is None else (audio_root / ee).as_posix() for ee in energies
]
energies = None if any(ee is None for ee in energies) else energies
return Text2UnitDataset(
split_name,
is_train_split,
cfg,
unit_labels,
n_frames,
src_texts,
tgt_texts,
speakers,
src_langs,
tgt_langs,
ids,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
durations,
pitches,
energies,
)
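# --- Illustrative note (not part of the original file) -----------------------
# A hypothetical manifest row as consumed by Text2UnitDatasetCreator._from_list.
# Column names follow the KEY_* constants above (plus those inherited from
# SpeechToTextDatasetCreator); every value here is invented for illustration.
_EXAMPLE_MANIFEST_ROW = {
    "id": "spk1-utt0001",
    "speaker": "spk1",
    "n_frames": "6",
    "tgt_text": "DH AH K",                 # reduced phones, space-separated
    "unit": "23 23 7 7 7 91",              # frame-level discrete units
    "duration": "2 3 1",                   # per-phone durations, summing to n_frames
    "pitch": "pitch/train/spk1-utt0001.npy",
    "energy": "energy/train/spk1-utt0001.npy",
}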
import argparse
from tqdm import tqdm
from pydub import AudioSegment
import torchaudio
import os
def mp3_convert_wav(mp3_file, wav_file):
try:
sound = AudioSegment.from_mp3(mp3_file)
sound = sound.set_frame_rate(16000)
sound = sound.set_channels(1)
sound = sound.set_sample_width(2)
sound.export(wav_file, format="wav")
except Exception as e:
print(e)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input", "-i", required=True, type=str)
parser.add_argument("--shard", "-n", required=True, type=int)
parser.add_argument("--rank", "-r", required=True, type=int)
args = parser.parse_args()
assert args.rank < args.shard, f"rank: {args.rank} >= shard: {args.shard}"
with open(args.input, 'r') as f:
files = [line.strip() for line in f ]
mp3_files = files[args.rank::args.shard]
for mp3_file in tqdm(mp3_files):
wav_file = mp3_file.replace("/clips/", "/wav/").replace(".mp3", ".wav")
if os.path.exists(wav_file):
try:
torchaudio.info(wav_file)
except Exception as e:
print(e)
mp3_convert_wav(mp3_file, wav_file)
else:
mp3_convert_wav(mp3_file, wav_file)
if __name__ == "__main__":
main()
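# Example invocation (sketch, file and script names are hypothetical): convert
# shard 0 of 8 of the mp3 paths listed in clips.lst to 16 kHz mono wav, writing
# into a parallel "/wav/" directory:
#   python convert_mp3_to_wav.py -i clips.lst -n 8 -r 0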
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
"""
Modified from: https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/examples/speech_to_text/prep_covost_data.py
1. normalize the punctuation
2. instead of extracting fbank features, we directly use the 16 kHz waveform
"""
import argparse
import logging
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Optional, Tuple
import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
filter_manifest_df,
gen_config_yaml,
gen_vocab,
load_df_from_tsv,
save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive
from tqdm import tqdm
from pydub import AudioSegment
import soundfile as sf
import sacremoses
log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text"]
def mp3_convert_wav(mp3_file, wav_file):
sound = AudioSegment.from_mp3(mp3_file)
sound = sound.set_frame_rate(16000)
sound = sound.set_channels(1)
sound = sound.set_sample_width(2)
sound.export(wav_file, format="wav")
class CoVoST(Dataset):
"""Create a Dataset for CoVoST (https://github.com/facebookresearch/covost).
Args:
root (str): root path to the dataset and generated manifests/features
source_language (str): source (audio) language
target_language (str, optional): target (text) language,
None for no translation (default: None)
version (int, optional): CoVoST version. (default: 2)
download (bool, optional): Whether to download the dataset if it is not
found at root path. (default: ``False``).
"""
COVOST_URL_TEMPLATE = (
"https://dl.fbaipublicfiles.com/covost/"
"covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
)
VERSIONS = {2}
SPLITS = ["train", "dev", "test"]
XX_EN_LANGUAGES = {
1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
2: [
"fr",
"de",
"es",
"ca",
"it",
"ru",
"zh-CN",
"pt",
"fa",
"et",
"mn",
"nl",
"tr",
"ar",
"sv-SE",
"lv",
"sl",
"ta",
"ja",
"id",
"cy",
],
}
EN_XX_LANGUAGES = {
1: [],
2: [
"de",
"tr",
"fa",
"sv-SE",
"mn",
"zh-CN",
"cy",
"ca",
"sl",
"et",
"id",
"ar",
"ta",
"lv",
"ja",
],
}
def __init__(
self,
root: str,
split: str,
source_language: str,
target_language: Optional[str] = None,
version: int = 2,
) -> None:
assert version in self.VERSIONS and split in self.SPLITS
assert source_language is not None
self.no_translation = target_language is None
if not self.no_translation:
assert "en" in {source_language, target_language}
if source_language == "en":
assert target_language in self.EN_XX_LANGUAGES[version]
else:
assert source_language in self.XX_EN_LANGUAGES[version]
else:
# Hack here so that we can get "split" column from CoVoST TSV.
# Note that we use CoVoST train split for ASR which is an extension
# to Common Voice train split.
target_language = "de" if source_language == "en" else "en"
self.root: Path = Path(root)
cv_tsv_path = self.root / "validated.tsv"
assert cv_tsv_path.is_file()
covost_url = self.COVOST_URL_TEMPLATE.format(
src_lang=source_language, tgt_lang=target_language
)
covost_archive = self.root / Path(covost_url).name
if not covost_archive.is_file():
download_url(covost_url, self.root.as_posix(), hash_value=None)
extract_archive(covost_archive.as_posix())
cv_tsv = load_df_from_tsv(cv_tsv_path)
covost_tsv = load_df_from_tsv(
self.root / Path(covost_url).name.replace(".tar.gz", "")
)
df = pd.merge(
left=cv_tsv[["path", "sentence", "client_id"]],
right=covost_tsv[["path", "translation", "split"]],
how="inner",
on="path",
)
if split == "train":
df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
else:
df = df[df["split"] == split]
data = df.to_dict(orient="index").items()
data = [v for k, v in sorted(data, key=lambda x: x[0])]
self.data = []
for e in data:
try:
path = self.root / "clips" / e["path"]
_ = torchaudio.info(path.as_posix())
self.data.append(e)
except RuntimeError:
pass
self.normalizer = sacremoses.MosesPunctNormalizer(
lang=target_language,
pre_replace_unicode_punct=True,
post_remove_control_chars=True,
)
def __getitem__(
self, n: int
) -> Tuple[Tensor, int, str, str, Optional[str], str, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
tuple: ``(audio_path, -1, sentence, translation, speaker_id,
sample_id)``; the waveform itself is not loaded here, only its path is returned.
"""
data = self.data[n]
path = self.root / "clips" / data["path"]
# waveform, sample_rate = torchaudio.load(path)
sentence = data["sentence"]
translation = None if self.no_translation else data["translation"]
translation = self.normalizer.normalize(translation)
speaker_id = data["client_id"]
_id = data["path"].replace(".mp3", "")
return path, -1, sentence, translation, speaker_id, _id
def __len__(self) -> int:
return len(self.data)
def process(args):
root = Path(args.data_root).absolute() / args.src_lang
outroot = root / f"{args.src_lang}-{args.tgt_lang}"
if args.vocab_type != "char":
outroot = root / f"{args.src_lang}-{args.tgt_lang}-{args.vocab_type}"
if not root.is_dir():
raise NotADirectoryError(f"{root} does not exist")
# 1. Extract features
# mp3-to-wav conversion can take a very long time; better to run it externally with multiple threads.
feature_root = root / "wav"
# feature_root.mkdir(exist_ok=True)
# for split in CoVoST.SPLITS:
# print(f"Fetching split {split}...")
# dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
# print("Converting mp3 to wav...")
# handle = open(root / f"{split}.id", "w")
# for waveform, _, _, _, _, utt_id in tqdm(dataset):
# wav_file = feature_root / f"{utt_id}.wav"
# print(waveform, file=handle)
# mp3_convert_wav(waveform, wav_file)
#2. Generate TSV manifest
print("Generating manifest...")
train_text = []
task = f"asr_{args.src_lang}"
if args.tgt_lang is not None:
task = f"st_{args.src_lang}_{args.tgt_lang}"
for split in CoVoST.SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
for waveform, _, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
wav_file = feature_root / f"{utt_id}.wav"
manifest["id"].append(utt_id)
manifest["audio"].append(wav_file.as_posix().replace("/data/", "/mnt/default/"))
manifest["n_frames"].append(sf.info(wav_file).frames)
manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
is_train_split = split.startswith("train")
if is_train_split:
train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split, min_n_frames=320, max_n_frames=480000)
save_df_to_tsv(df, outroot / f"{split}_{task}.tsv")
# Generate vocab
vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + "\n")
gen_vocab(
Path(f.name),
outroot / spm_filename_prefix,
args.vocab_type,
args.vocab_size
)
# Generate config YAML
# gen_config_yaml(
# outroot,
# spm_filename=spm_filename_prefix + ".model",
# yaml_filename=f"config_{task}.yaml",
# specaugment_policy="lb",
# )
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--data-root", "-d", required=True, type=str,
help="data root with sub-folders for each language <root>/<src_lang>"
)
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
),
parser.add_argument("--vocab-size", default=1000, type=int)
parser.add_argument("--src-lang", "-s", required=True, type=str)
parser.add_argument("--tgt-lang", "-t", type=str)
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
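# Example invocation (sketch, paths and script name are hypothetical): prepare
# fr->en CoVoST ST manifests and a unigram vocabulary of size 1000 under
# <data-root>/fr/fr-en-unigram/:
#   python prep_covost_data.py -d /data/covost -s fr -t en --vocab-type unigram --vocab-size 1000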
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import os
import argparse
from tqdm import tqdm
import numpy as np
lg_label = "__label__{}"
def writefile(filename, lines):
with open(filename, 'w', encoding='utf-8') as f:
f.writelines(lines)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input", "-i", required=True, type=str)
parser.add_argument("--output", "-o", required=True, type=str)
parser.add_argument("--src", "-s", required=True, type=str)
parser.add_argument("--tgt", "-t", required=True, type=str)
parser.add_argument("--max-len", "-m", default=2998, type=int)
args = parser.parse_args()
src_lines, tgt_lines = [], []
with open(f"{args.input}.{args.src}", 'r') as f1, open(f"{args.input}.{args.tgt}", 'r') as f2:
for src_line, tgt_line in tqdm(zip(f1, f2)):
src_len = len(src_line.strip().split())
tgt_len = len(tgt_line.strip().split())
if src_len < args.max_len and src_len > 0 and tgt_len < args.max_len and tgt_len > 0:
src_lines.append(src_line)
tgt_lines.append(tgt_line)
writefile(f"{args.output}.{args.src}", src_lines)
writefile(f"{args.output}.{args.tgt}", tgt_lines)
if __name__ == "__main__":
main()
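# Example invocation (sketch, file and script names are hypothetical): keep only
# parallel lines whose whitespace-tokenized length is in (0, 2998) on both sides,
# reading {input}.{src}/{input}.{tgt} and writing {output}.{src}/{output}.{tgt}:
#   python filter_paired_text.py -i corpus/train -o corpus/train.clean -s de -t en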
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import argparse
import logging
from pathlib import Path
from collections import defaultdict
import pandas as pd
import torchaudio
from tqdm import tqdm
import numpy as np
import torch
from fairseq.data.audio.audio_utils import convert_waveform
from examples.speech_to_text.data_utils import save_df_to_tsv
from examples.speech_synthesis.data_utils import extract_pitch
log = logging.getLogger(__name__)
def get_duration(fa_phone):
"""fa_phone: force-aligned phone, 1-D numpy"""
same = np.concatenate(([True], fa_phone[:-1] != fa_phone[1:], [True]))
index = np.where(same)[0]
count = np.diff(index)
return count
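# Illustrative example (not part of the original file): get_duration returns the
# run length of each consecutive phone, e.g.
#   get_duration(np.array([5, 5, 5, 2, 2, 7])) -> array([3, 2, 1])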
def process(args):
# assert "train" in args.splits
out_root = Path(args.output_root).absolute()
out_root.mkdir(exist_ok=True)
print("Fetching data...")
audio_manifest_root = Path(args.audio_manifest_root).absolute()
for s in args.splits:
if args.add_pitch:
pitch_root = out_root / "pitch" / s
pitch_root.mkdir(exist_ok=True)
manifest = defaultdict(list)
with open(audio_manifest_root / f"{s}.audio.tsv") as f1, \
open(audio_manifest_root / f"{s}.phn") as f2, \
open(audio_manifest_root / f"{s}.km") as f3:
audio_root = f1.readline().strip()
audio_root = Path(audio_root)
for audio_path, fa_phone, fa_unit in tqdm(zip(f1, f2, f3)):
record = True
audio_path, n_frames = audio_path.strip().split("\t")
fa_phone = fa_phone.strip().split()
fa_unit = fa_unit.strip()
uttid = audio_path.split("/")[-1].split(".")[0]
speaker = uttid.split("-")[0]
if args.add_duration:
assert len(fa_phone) == len(fa_unit.split())
fa_phone = np.array(list(map(int, fa_phone)))
duration = get_duration(fa_phone)
reduced_phone = torch.LongTensor(fa_phone).unique_consecutive().numpy()
if args.add_pitch:
pitch_path = pitch_root / f"{uttid}.npy"
if not pitch_path.is_file():
waveform, sample_rate = torchaudio.load(audio_root / audio_path)
waveform, sample_rate = convert_waveform(
waveform, sample_rate, normalize_volume=args.normalize_volume,
)
pitch = extract_pitch(
waveform, sample_rate, None,
hop_length=args.hop_length, log_scale=True,
phoneme_durations=duration
)
if pitch is not None:
np.save(pitch_path.as_posix(), pitch)
else:
record = False
else:
reduced_phone = fa_phone
if record:
manifest["id"].append(uttid)
manifest["speaker"].append(speaker)
manifest["n_frames"].append(len(fa_unit.split()))
manifest["tgt_text"].append(" ".join(map(str, reduced_phone)))
manifest["unit"].append(fa_unit)
if args.add_duration:
manifest["duration"].append(" ".join(map(str, duration)))
if args.add_pitch:
manifest["pitch"].append(f"pitch/{s}/{uttid}.npy")
save_df_to_tsv(
pd.DataFrame.from_dict(manifest),
out_root / f"{s}.tsv"
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--audio-manifest-root", "-m", type=str)
parser.add_argument("--output-root", "-o", required=True, type=str)
parser.add_argument("--splits", "-s", type=str, nargs="+",
default=["train", "dev", "test"])
parser.add_argument("--normalize-volume", "-n", action="store_true")
parser.add_argument("--hop-length", type=int, default=256)
parser.add_argument("--add-duration", action="store_true")
parser.add_argument("--add-pitch", action="store_true")
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import argparse
import logging
from pathlib import Path
from collections import defaultdict
import pandas as pd
from tqdm import tqdm
import numpy as np
from examples.speech_to_text.data_utils import save_df_to_tsv
log = logging.getLogger(__name__)
def get_duration(fa_phone):
"""fa_phone: force-aligned phone, 1-D numpy"""
same = np.concatenate(([True], fa_phone[:-1] != fa_phone[1:], [True]))
index = np.where(same)[0]
count = np.diff(index)
return count
def process(args):
# assert "train" in args.splits
out_root = Path(args.output_root).absolute()
out_root.mkdir(exist_ok=True)
print("Fetching data...")
audio_manifest_root = Path(args.audio_manifest_root).absolute()
for s in args.splits:
manifest = defaultdict(list)
with open(audio_manifest_root / f"{s}.phn") as f1:
for i, reduced_phone in tqdm(enumerate(f1)):
reduced_phone = reduced_phone.strip()
uttid = f"librilm-{i}"
speaker = uttid.split("-")[0]
manifest["id"].append(uttid)
manifest["speaker"].append(speaker)
manifest["n_frames"].append(len(reduced_phone))
manifest["tgt_text"].append(reduced_phone)
manifest["unit"].append(0)
save_df_to_tsv(
pd.DataFrame.from_dict(manifest),
out_root / f"{s}.tsv"
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--audio-manifest-root", "-m", type=str)
parser.add_argument("--output-root", "-o", required=True, type=str)
parser.add_argument("--splits", "-s", type=str, nargs="+",
default=["train", "dev", "test"])
parser.add_argument("--add-fastspeech-targets", action="store_true")
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py
"""
import argparse
import numpy as np
import sys
from g2p_en import G2p
from tqdm import tqdm
import logging
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
def get_parser():
parser = argparse.ArgumentParser(
description="converts words to phones adding optional silences around in between words"
)
parser.add_argument(
"--sil-prob",
"-s",
type=float,
default=0,
help="probability of inserting silence between each word",
)
parser.add_argument(
"--surround",
action="store_true",
help="if set, surrounds each example with silence",
)
parser.add_argument(
"--lexicon",
help="lexicon to convert to phones",
required=True,
)
parser.add_argument(
"--strict",
action="store_true",
help="if set, OOV words will raise a error (for train/valid set)",
)
parser.add_argument(
"--input",
"-i",
help="input text file",
required=True,
)
parser.add_argument(
"--output",
"-o",
help="input text file",
required=True,
)
return parser
def normalize_phn(phons):
"""
    convert g2p-style phones (with stress digits) to the 39-phone set
"""
return [p.rstrip('0123456789') for p in phons]
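# For illustration (a hedged example assuming typical g2p_en output for "hello"):
#   normalize_phn(['HH', 'AH0', 'L', 'OW1']) -> ['HH', 'AH', 'L', 'OW']
# i.e. stress digits are stripped so the phones fall back into the 39-phone set.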
def main():
parser = get_parser()
args = parser.parse_args()
sil_prob = args.sil_prob
surround = args.surround
sil = "<SIL>"
wrd_to_phn = {}
g2p = G2p()
with open(args.lexicon, "r") as lf:
for line in lf:
items = line.rstrip().split()
assert len(items) > 1, line
assert items[0] not in wrd_to_phn, items
wrd_to_phn[items[0]] = items[1:]
with open(args.input, "r") as fin, open(args.output, "w", encoding="utf-8") as fout:
for line in tqdm(fin):
words = line.strip().upper().split()
if not all(w in wrd_to_phn for w in words):
if args.strict:
# logger.warning(f"| Warning: OOV words found: {line}")
pass
else:
continue
phones = []
if surround:
phones.append(sil)
sample_sil_probs = None
if sil_prob > 0 and len(words) > 1:
sample_sil_probs = np.random.random(len(words) - 1)
for i, w in enumerate(words):
if w in wrd_to_phn:
phones.extend(wrd_to_phn[w])
else:
phones.extend(normalize_phn(g2p(w)))
if (
sample_sil_probs is not None
and i < len(sample_sil_probs)
and sample_sil_probs[i] < sil_prob
):
phones.append(sil)
if surround:
phones.append(sil)
print(" ".join(phones), file=fout)
if __name__ == "__main__":
main()
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import os
import tqdm
import argparse
import numpy as np
parser = argparse.ArgumentParser()
parser.add_argument("--input", "-i", required=True, type=str)
parser.add_argument("--output", "-o", required=True, type=str)
parser.add_argument("--lexicon", default='align_lexicon.txt', type=str)
args = parser.parse_args()
sil_prob = 0.25
if not os.path.exists(args.lexicon):
print(f"| Warning: lexicon {args.lexicon} not found, downloading ...")
try:
os.system(f"wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1QVeyCpLXLnujBUAickpo-jaSVY-vKLnT' -O {args.lexicon}")
except Exception as e:
print(e)
print(f"| Error downloading {args.lexicon}, please download it from https://drive.google.com/file/d/1QVeyCpLXLnujBUAickpo-jaSVY-vKLnT/view?usp=sharing")
exit(1)
dict = {}
f = open(args.lexicon)
for l in f:
dict[l.split()[0]] = l.split()[2:]
assert l.split()[0] == l.split()[1]
f = open(args.input, 'r')
w_f = open(f'{args.output}.kaldi_phn_sil025', 'w')
w_oov = open(f'{args.output}.kaldi_phn_sil025.oov', 'w')
oov_nums = 0
total_nums = 0
for l in tqdm.tqdm(f):
words = l.strip().replace(" ", "").split("|")
# words = l.strip().upper().split()
words = [w for w in words if w != '']
phones = []
phones.extend(dict['!SIL'])
sample_sil_probs = None
if sil_prob > 0 and len(words) > 1:
sample_sil_probs = np.random.random(len(words) - 1)
for i, w in enumerate(words):
total_nums += 1
if w not in dict:
w = '<UNK>'
oov_nums += 1
w_oov.write(w + '\n')
phones.extend(dict[w])
if (
sample_sil_probs is not None
and i < len(sample_sil_probs)
and sample_sil_probs[i] < sil_prob
):
phones.extend(dict['!SIL'])
phones.extend(dict['!SIL'])
w_f.write(' '.join(phones) + '\n')
w_oov.write(f'{oov_nums}\n')
print(f"OOV rate: {oov_nums}/{total_nums}")
# !!! After processing, use this command to adjust the SIL
### sed -i 's/SIL_S/SIL/g' your_file
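# A rough Python equivalent of the sed command above (a hedged sketch with placeholder file
# names; it writes to a new file instead of editing in place):
#   with open("your_file") as fin, open("your_file.sil_fixed", "w") as fout:
#       for line in fin:
#           fout.write(line.replace("SIL_S", "SIL"))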
{"SIL": [14, 7], "AE1_I": [5, 2.5], "P_I": [5, 2.5], "T_I": [5, 2.5], "ER0_E": [5, 2.5], "W_B": [5, 2.5], "AH1_I": [5, 2.5], "N_E": [5, 2.5], "M_B": [5, 2.5], "IH1_I": [5, 2.5], "S_I": [5, 2.5], "IH0_I": [5, 2.5], "Z_E": [5, 2.5], "R_B": [5, 2.5], "EY1_I": [5, 2.5], "CH_I": [5, 2.5], "AH0_I": [5, 2.5], "L_E": [5, 2.5], "L_B": [5, 2.5], "N_I": [5, 2.5], "D_E": [5, 2.5], "IH0_B": [5, 2.5], "S_B": [5, 2.5], "R_I": [5, 2.5], "AY1_I": [5, 2.5], "Z_I": [5, 2.5], "V_I": [5, 2.5], "JH_B": [5, 2.5], "T_E": [5, 2.5], "EH1_I": [5, 2.5], "R_E": [5, 2.5], "DH_B": [5, 2.5], "IY0_E": [5, 2.5], "AE1_B": [5, 2.5], "L_I": [5, 2.5], "IY2_E": [5, 2.5], "OW1_I": [5, 2.5], "D_B": [5, 2.5], "AW1_I": [5, 2.5], "UW1_E": [5, 2.5], "AH0_S": [5, 2.5], "HH_B": [5, 2.5], "AA1_I": [5, 2.5], "OW0_E": [5, 2.5], "F_B": [5, 2.5], "JH_I": [5, 2.5], "TH_E": [5, 2.5], "AO1_B": [5, 2.5], "D_I": [5, 2.5], "ER0_I": [5, 2.5], "AH0_B": [5, 2.5], "IY0_I": [5, 2.5], "IH1_B": [5, 2.5], "AA2_I": [5, 2.5], "S_E": [5, 2.5], "T_B": [5, 2.5], "ER1_I": [5, 2.5], "B_B": [5, 2.5], "AY1_E": [5, 2.5], "UH1_I": [5, 2.5], "K_E": [5, 2.5], "AO1_I": [5, 2.5], "W_I": [5, 2.5], "EY1_E": [5, 2.5], "AH1_E": [5, 2.5], "V_E": [5, 2.5], "OW1_B": [5, 2.5], "K_B": [5, 2.5], "TH_I": [5, 2.5], "B_I": [5, 2.5], "P_B": [5, 2.5], "Y_I": [5, 2.5], "UW1_I": [5, 2.5], "IH0_E": [5, 2.5], "IY1_E": [5, 2.5], "K_I": [5, 2.5], "AO2_I": [5, 2.5], "NG_E": [5, 2.5], "ER1_B": [5, 2.5], "TH_B": [5, 2.5], "IY1_I": [5, 2.5], "AE0_I": [5, 2.5], "AH0_E": [5, 2.5], "M_E": [5, 2.5], "N_B": [5, 2.5], "IY1_B": [5, 2.5], "DH_I": [5, 2.5], "G_I": [5, 2.5], "SH_I": [5, 2.5], "SH_B": [5, 2.5], "P_E": [5, 2.5], "AY1_S": [5, 2.5], "AA1_B": [5, 2.5], "EH1_B": [5, 2.5], "IH2_I": [5, 2.5], "AH1_B": [5, 2.5], "F_E": [5, 2.5], "AW1_B": [5, 2.5], "F_I": [5, 2.5], "EH2_I": [5, 2.5], "JH_E": [5, 2.5], "AY2_I": [5, 2.5], "EY2_E": [5, 2.5], "NG_I": [5, 2.5], "CH_E": [5, 2.5], "EY1_B": [5, 2.5], "AA0_B": [5, 2.5], "Y_B": [5, 2.5], "DH_E": [5, 2.5], "IY2_I": [5, 2.5], "V_B": [5, 2.5], "OY1_I": [5, 2.5], "UW0_E": [5, 2.5], "OW1_E": [5, 2.5], "G_B": [5, 2.5], "AE2_B": [5, 2.5], "M_I": [5, 2.5], "SH_E": [5, 2.5], "IH2_B": [5, 2.5], "AW1_E": [5, 2.5], "ZH_I": [5, 2.5], "ER0_S": [5, 2.5], "AY1_B": [5, 2.5], "AA0_I": [5, 2.5], "G_E": [5, 2.5], "EH0_B": [5, 2.5], "SPN_S": [32, 11], "UW2_I": [5, 2.5], "UW0_I": [5, 2.5], "EY2_I": [5, 2.5], "ER1_E": [5, 2.5], "OW2_I": [5, 2.5], "OW0_I": [5, 2.5], "HH_I": [5, 2.5], "B_E": [5, 2.5], "AO1_E": [5, 2.5], "AH2_B": [5, 2.5], "UH2_I": [5, 2.5], "OW1_S": [5, 2.5], "AO2_B": [5, 2.5], "OY1_E": [5, 2.5], "AE2_I": [5, 2.5], "AO0_B": [5, 2.5], "EH2_B": [5, 2.5], "EY1_S": [5, 2.5], "AE0_B": [5, 2.5], "ER0_B": [5, 2.5], "EH0_I": [5, 2.5], "EY0_I": [5, 2.5], "AW2_E": [5, 2.5], "AW2_I": [5, 2.5], "AY0_B": [5, 2.5], "AA2_B": [5, 2.5], "EY0_E": [5, 2.5], "AO0_I": [5, 2.5], "AY0_I": [5, 2.5], "AH2_I": [5, 2.5], "OW2_E": [5, 2.5], "ZH_E": [5, 2.5], "AY2_E": [5, 2.5], "ER2_I": [5, 2.5], "IY2_B": [5, 2.5], "AA1_S": [5, 2.5], "AA1_E": [5, 2.5], "OY0_I": [5, 2.5], "IY0_B": [5, 2.5], "OY2_E": [5, 2.5], "OW2_B": [5, 2.5], "AY0_E": [5, 2.5], "OY2_I": [5, 2.5], "UW1_B": [5, 2.5], "OY0_E": [5, 2.5], "UH0_I": [5, 2.5], "OY1_B": [5, 2.5], "AW0_B": [5, 2.5], "AO1_S": [5, 2.5], "OW0_B": [5, 2.5], "EH1_S": [5, 2.5], "AW0_I": [5, 2.5], "UW0_B": [5, 2.5], "AO2_E": [5, 2.5], "UW2_E": [5, 2.5], "L_S": [5, 2.5], "Z_B": [5, 2.5], "AA2_E": [5, 2.5], "EY0_B": [5, 2.5], "AY2_B": [5, 2.5], "AW0_E": [5, 2.5], "IY1_S": [5, 2.5], "EY2_B": [5, 2.5], "AH1_S": [5, 2.5], "IH2_E": [5, 2.5], "AW2_B": [5, 2.5], 
"AA0_E": [5, 2.5], "ER2_E": [5, 2.5], "ZH_B": [5, 2.5], "UH1_E": [5, 2.5], "EH1_E": [5, 2.5], "IH1_E": [5, 2.5], "ER1_S": [5, 2.5], "EH2_E": [5, 2.5], "AO0_E": [5, 2.5], "OY1_S": [5, 2.5], "AA_B": [5, 2.5], "AA_E": [5, 2.5], "AA_I": [5, 2.5], "AA_S": [5, 2.5], "AA0_S": [5, 2.5], "AA2_S": [5, 2.5], "AE_B": [5, 2.5], "AE_E": [5, 2.5], "AE_I": [5, 2.5], "AE_S": [5, 2.5], "AE0_E": [5, 2.5], "AE0_S": [5, 2.5], "AE1_E": [5, 2.5], "AE1_S": [5, 2.5], "AE2_E": [5, 2.5], "AE2_S": [5, 2.5], "AH_B": [5, 2.5], "AH_E": [5, 2.5], "AH_I": [5, 2.5], "AH_S": [5, 2.5], "AH2_E": [5, 2.5], "AH2_S": [5, 2.5], "AO_B": [5, 2.5], "AO_E": [5, 2.5], "AO_I": [5, 2.5], "AO_S": [5, 2.5], "AO0_S": [5, 2.5], "AO2_S": [5, 2.5], "AW_B": [5, 2.5], "AW_E": [5, 2.5], "AW_I": [5, 2.5], "AW_S": [5, 2.5], "AW0_S": [5, 2.5], "AW1_S": [5, 2.5], "AW2_S": [5, 2.5], "AY_B": [5, 2.5], "AY_E": [5, 2.5], "AY_I": [5, 2.5], "AY_S": [5, 2.5], "AY0_S": [5, 2.5], "AY2_S": [5, 2.5], "B_S": [5, 2.5], "CH_S": [5, 2.5], "D_S": [5, 2.5], "DH_S": [5, 2.5], "EH_B": [5, 2.5], "EH_E": [5, 2.5], "EH_I": [5, 2.5], "EH_S": [5, 2.5], "EH0_E": [5, 2.5], "EH0_S": [5, 2.5], "EH2_S": [5, 2.5], "ER_B": [5, 2.5], "ER_E": [5, 2.5], "ER_I": [5, 2.5], "ER_S": [5, 2.5], "ER2_B": [5, 2.5], "ER2_S": [5, 2.5], "EY_B": [5, 2.5], "EY_E": [5, 2.5], "EY_I": [5, 2.5], "EY_S": [5, 2.5], "EY0_S": [5, 2.5], "EY2_S": [5, 2.5], "F_S": [5, 2.5], "G_S": [5, 2.5], "HH_E": [5, 2.5], "HH_S": [5, 2.5], "IH_B": [5, 2.5], "IH_E": [5, 2.5], "IH_I": [5, 2.5], "IH_S": [5, 2.5], "IH0_S": [5, 2.5], "IH1_S": [5, 2.5], "IH2_S": [5, 2.5], "IY_B": [5, 2.5], "IY_E": [5, 2.5], "IY_I": [5, 2.5], "IY_S": [5, 2.5], "IY0_S": [5, 2.5], "IY2_S": [5, 2.5], "JH_S": [5, 2.5], "K_S": [5, 2.5], "M_S": [5, 2.5], "N_S": [5, 2.5], "NG_B": [5, 2.5], "NG_S": [5, 2.5], "OW_B": [5, 2.5], "OW_E": [5, 2.5], "OW_I": [5, 2.5], "OW_S": [5, 2.5], "OW0_S": [5, 2.5], "OW2_S": [5, 2.5], "OY_B": [5, 2.5], "OY_E": [5, 2.5], "OY_I": [5, 2.5], "OY_S": [5, 2.5], "OY0_B": [5, 2.5], "OY0_S": [5, 2.5], "OY2_B": [5, 2.5], "OY2_S": [5, 2.5], "P_S": [5, 2.5], "R_S": [5, 2.5], "S_S": [5, 2.5], "SH_S": [5, 2.5], "T_S": [5, 2.5], "TH_S": [5, 2.5], "UH_B": [5, 2.5], "UH_E": [5, 2.5], "UH_I": [5, 2.5], "UH_S": [5, 2.5], "UH0_B": [5, 2.5], "UH0_E": [5, 2.5], "UH0_S": [5, 2.5], "UH1_B": [5, 2.5], "UH1_S": [5, 2.5], "UH2_B": [5, 2.5], "UH2_E": [5, 2.5], "UH2_S": [5, 2.5], "UW_B": [5, 2.5], "UW_E": [5, 2.5], "UW_I": [5, 2.5], "UW_S": [5, 2.5], "UW0_S": [5, 2.5], "UW1_S": [5, 2.5], "UW2_B": [5, 2.5], "UW2_S": [5, 2.5], "V_S": [5, 2.5], "W_E": [5, 2.5], "W_S": [5, 2.5], "Y_E": [5, 2.5], "Y_S": [5, 2.5], "Z_S": [5, 2.5], "ZH_S": [5, 2.5]}
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import sys, json, tqdm
import numpy as np
input_file = sys.argv[1]
mean_and_std_file = sys.argv[2]
out_file = sys.argv[3]
mean_and_std = json.load(open(mean_and_std_file, 'r'))
with open(input_file, 'r') as f, open(out_file, 'w') as w:
for line in tqdm.tqdm(f):
l = line.split()
new_l = []
for phn in l:
if phn not in mean_and_std:
mean_and_std[phn] = [5, 2.5]
print(f'unk phone {phn}')
n = max(1, round(np.random.normal(loc=mean_and_std[phn][0], scale=mean_and_std[phn][1])))
new_l.extend([phn] * int(n))
minus = 0
while len(new_l) >= 4375:
minus += 1
new_l = []
for phn in l:
n = max(1, round(mean_and_std[phn][0] - minus))
new_l.extend([phn] * n)
print(f"too long line try minus {minus}")
w.write(' '.join(new_l)+'\n')
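# For illustration (a hedged sketch with made-up phones): given the mean/std dictionary shown
# earlier in this commit, a line "SIL HH_B AH0_I" would typically expand to roughly 14 copies
# of SIL followed by about 5 copies of each remaining phone, since each count is drawn from a
# Gaussian with the listed per-phone mean and standard deviation (minimum 1 copy); lines that
# would reach 4375 tokens are rebuilt deterministically with progressively smaller means.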
#!/bin/bash
[ ${PWD##*/} != SpeechLM ] && echo "Error: wrong directory! Switch to SpeechLM/ and run it again!" && exit 1
[ $# -lt 1 ] && echo "Usage: $0 <lang> [root=${PWD}/dataset/CommonVoice/v4]" && exit 0
cwd=${PWD}
src=${PWD}/speechlm/data_process
lang=$1
root=$2
[ -z $root ] && root="${PWD}/dataset/CommonVoice/v4"
set -e -o pipefail -u
### step1, convert mp3 to wav
cd $root/en && mkdir -p wav
cut -f2 validated.tsv | sed '1d' | sed "s|^|${root}/en/clips/|" > validated.id
for i in $(seq 0 39); do
echo extracting $i;
python $src/covost2/mp3_to_wav.py -i validated.id -n 40 -r $i &
done
wait
cd $cwd
### step2, manifest
datadir="$root/en/en-$lang" && mkdir -p $datadir && cd $datadir
python $src/covost2/prepare_covost_data.py --data-root $root --src-lang en --tgt-lang $lang --vocab-type char
mv ../*en_${lang}.* ./
# create config_base_en${lang}.yaml
echo "bpe_tokenizer:" > config_base_en${lang}.yaml
echo " bpe: sentencepiece" >> config_base_en${lang}.yaml
echo " sentencepiece_model: spm_char_st_en_de.model" >> config_base_en${lang}.yaml
echo "" >> config_base_en${lang}.yaml
echo "shuffle: false" >> config_base_en${lang}.yaml
echo "use_audio_input: true" >> config_base_en${lang}.yaml
echo "use_sample_rate: 16000" >> config_base_en${lang}.yaml
echo "standardize_audio: false" >> config_base_en${lang}.yaml
echo "vocab_filename: spm_char_st_en_de.txt" >> config_base_en${lang}.yaml
echo "" >> config_base_en${lang}.yaml
echo "# required by speech_to_text task but never used" >> config_base_en${lang}.yaml
echo "input_channels: 1" >> config_base_en${lang}.yaml
echo "input_feat_per_channel: 1" >> config_base_en${lang}.yaml
echo "" >> config_base_en${lang}.yaml
# create config_large_en${lang}.yaml from the base config (only standardize_audio differs)
cat config_base_en${lang}.yaml | sed "s|standardize_audio: false|standardize_audio: true|" > config_large_en${lang}.yaml
#!/bin/bash
[ ${PWD##*/} != SpeechLM ] && echo "Error: wrong directory! Switch to SpeechLM/ and run it again!" && exit 1
cwd=${PWD}
src=${PWD}/speechlm/data_process
set -e
mkdir -p dataset/LibriLM/phone_unit/tmp && cd dataset/LibriLM
if [ ! -f librispeech-lm-norm.txt ]; then
echo "--------------------------------------------------------------------------------------"
echo "--------Downloading and unpacking librispeech-lm-norm.txt ..."
echo "--------------------------------------------------------------------------------------"
wget -c https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
gzip -d librispeech-lm-norm.txt.gz
fi
# head -1000000 librispeech-lm-norm.txt > phone_unit/tmp/librispeech-lm-norm.txt
cd phone_unit/
echo "--------------------------------------------------------------------------------------"
echo "--------Tokenize the text..."
echo "--------------------------------------------------------------------------------------"
cat ../librispeech-lm-norm.txt | sed '1d' | python $src/wrd2ltr.py > tmp/librilm.ltr
echo "--------------------------------------------------------------------------------------"
echo "--------Tokenize the text to the kaldi-style phonemes ..."
echo "--------------------------------------------------------------------------------------"
python $src/phoneme_tokenizer/ltr2kaldi_phn_sil025.py -i tmp/librilm.ltr -o tmp/librilm
cat tmp/librilm.kaldi_phn_sil025 | sed 's/SIL_S/SIL/g' > tmp/librilm.phn
echo "--------------------------------------------------------------------------------------"
echo "--------Filter too long samples and up-sample phonemes ..."
echo "--------------------------------------------------------------------------------------"
python $src/filter_paireddata_by_len.py -i tmp/librilm -o tmp/librilm_l2k -s phn -t ltr -m 2000
python $src/phoneme_tokenizer/repeat_withou_insert_sil_less_4375.py \
tmp/librilm_l2k.phn \
$src/phoneme_tokenizer/mean5_and_std25_sil14_spn32.dict \
tmp/librilm_l2k_upsample.phn
mv tmp/librilm_l2k.ltr tmp/librilm_l2k_upsample.ltr
python $src/filter_paireddata_by_len.py -i tmp/librilm_l2k_upsample -o train_text.phn-ltr -s phn -t ltr -m 2800
### The max length here filters the data according to the batch size (in the Large setting, 900,000 / 320 ≈ 2812 tokens per batch).
echo "--------------------------------------------------------------------------------------"
echo "--------Create binary files ..."
echo "--------------------------------------------------------------------------------------"
[ ! -f bin-idx/dict.phn.txt ] && echo "dict ${cwd}/dataset/LibriLM/bin-idx/dict.phn.txt not found!" && exit 1
[ ! -f bin-idx/dict.ltr.txt ] && echo "dict ${cwd}/dataset/LibriLM/bin-idx/dict.ltr.txt not found!" && exit 1
bash $src/txt2idx.sh train_text.phn-ltr.phn bin-idx bin-idx/dict.phn.txt
bash $src/txt2idx.sh train_text.phn-ltr.ltr bin-idx bin-idx/dict.ltr.txt
rm -r tmp
cd -
echo "--------------------------------------------------------------------------------------"
echo "--------Done! files are in ${PWD}/dataset/LibriLM"
echo "--------------------------------------------------------------------------------------"
#!/bin/bash
[ $# -lt 3 ] && echo "Usage: $0 <input-text> <outdir> <DICT> [suffix]" && exit 0
input=$1
outdir=$2
DICT=$3
suffix=$4
outname=${input##*/}
outname=${outname%.txt*}
[ -z $input ] && echo "You must specify a source file" && exit 1
[ -z $DICT ] && echo "No dict was specified!" && exit 1
[ -z $outdir ] && outdir=${input%/*}
[ -z $outdir ] && outdir="."
[ ! -d $outdir ] && mkdir -p $outdir
echo "------------------------------- creating idx/bin--------------------------------------------"
echo "$input --> $outdir/${outname}${suffix}.idx"
fairseq-preprocess \
--only-source \
--trainpref $input \
--destdir $outdir \
--thresholdsrc 0 \
--srcdict ${DICT} \
--workers 40
mv $outdir/train.idx $outdir/${outname}${suffix}.idx
mv $outdir/train.bin $outdir/${outname}${suffix}.bin
echo "----------------------------------- done --------------------------------------------"
import sys
def main():
for line in sys.stdin:
line = line.replace("<unk>", "")
line = " ".join(line.strip().split())
line = line.replace(" ", "|").upper() + "|"
print(" ".join(line))
if __name__ == "__main__":
main()
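# For illustration (a hedged example): feeding the line "hello world" through this script
# prints "H E L L O | W O R L D |", i.e. upper-cased letters separated by spaces with "|"
# marking word boundaries, matching the .ltr letter format used elsewhere in this pipeline.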
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
"""
Modified from: https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/fairseq_cli/generate.py
"""
import ast
import logging
import math
import os
import sys
from argparse import Namespace
from itertools import chain
import numpy as np
import torch
from omegaconf import DictConfig
from fairseq import checkpoint_utils, options, scoring, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter
def main(cfg: DictConfig):
if isinstance(cfg, Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
assert cfg.common_eval.path is not None, "--path required for generation!"
assert (
not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"
if cfg.common_eval.results_path is not None:
os.makedirs(cfg.common_eval.results_path, exist_ok=True)
output_path = os.path.join(
cfg.common_eval.results_path,
"generate-{}.txt".format(cfg.dataset.gen_subset),
)
with open(output_path, "w", buffering=1, encoding="utf-8") as h:
return _main(cfg, h)
else:
return _main(cfg, sys.stdout)
def get_symbols_to_strip_from_output(generator):
if hasattr(generator, "symbols_to_strip_from_output"):
return generator.symbols_to_strip_from_output
else:
return {generator.eos}
def _main(cfg: DictConfig, output_file):
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=output_file,
)
logger = logging.getLogger("fairseq_cli.generate")
utils.import_user_module(cfg.common)
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
cfg.dataset.max_tokens = 12000
logger.info(cfg)
# Fix seed for stochastic decoding
if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
np.random.seed(cfg.common.seed)
utils.set_torch_seed(cfg.common.seed)
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
# Load dataset splits
task = tasks.setup_task(cfg.task)
# Set dictionaries
try:
src_dict = getattr(task, "source_dictionary", None)
except NotImplementedError:
src_dict = None
tgt_dict = task.target_dictionary
overrides = ast.literal_eval(cfg.common_eval.model_overrides)
# Load ensemble
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
models, saved_cfg = checkpoint_utils.load_model_ensemble(
utils.split_paths(cfg.common_eval.path),
arg_overrides=overrides,
task=task,
suffix=cfg.checkpoint.checkpoint_suffix,
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=cfg.checkpoint.checkpoint_shard_count,
)
# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
if cfg.generation.lm_path is not None:
overrides["data"] = cfg.task.data
try:
lms, _ = checkpoint_utils.load_model_ensemble(
[cfg.generation.lm_path], arg_overrides=overrides, task=None
)
except:
logger.warning(
f"Failed to load language model! Please make sure that the language model dict is the same "
f"as target dict and is located in the data dir ({cfg.task.data})"
)
raise
assert len(lms) == 1
else:
lms = [None]
# Optimize ensemble for generation
for model in chain(models, lms):
if model is None:
continue
if cfg.common.fp16:
model.half()
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
model.cuda()
model.prepare_for_inference_(cfg)
def _fp_convert_sample(sample):
def apply_half(t):
if t.dtype is torch.float32:
return t.to(dtype=torch.half)
return t
def apply_bfloat16(t):
if t.dtype is torch.float32:
return t.to(dtype=torch.bfloat16)
return t
if cfg.common.fp16:
sample = utils.apply_to_sample(apply_half, sample)
if cfg.common.bf16:
sample = utils.apply_to_sample(apply_bfloat16, sample)
return sample
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(cfg.generation.replace_unk)
# Load dataset (possibly sharded)
itr = task.get_batch_iterator(
dataset=task.dataset(cfg.dataset.gen_subset),
max_tokens=cfg.dataset.max_tokens,
max_sentences=cfg.dataset.batch_size,
max_positions=utils.resolve_max_positions(
task.max_positions(), *[m.max_positions() for m in models]
),
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
seed=cfg.common.seed,
num_shards=cfg.distributed_training.distributed_world_size,
shard_id=cfg.distributed_training.distributed_rank,
num_workers=cfg.dataset.num_workers,
data_buffer_size=cfg.dataset.data_buffer_size,
).next_epoch_itr(shuffle=False)
progress = progress_bar.progress_bar(
itr,
log_format=cfg.common.log_format,
log_interval=cfg.common.log_interval,
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
)
# Initialize generator
gen_timer = StopwatchMeter()
extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
generator = task.build_generator(
models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
)
# Handle tokenization and BPE
tokenizer = task.build_tokenizer(cfg.tokenizer)
bpe = task.build_bpe(cfg.bpe)
def decode_fn(x):
if bpe is not None:
x = bpe.decode(x)
if tokenizer is not None:
x = tokenizer.decode(x)
return x
scorer = scoring.build_scorer(cfg.scoring, None)
num_sentences = 0
has_target = True
wps_meter = TimeMeter()
for sample in progress:
sample = utils.move_to_cuda(sample) if use_cuda else sample
sample = _fp_convert_sample(sample)
if "net_input" not in sample:
continue
prefix_tokens = None
if cfg.generation.prefix_size > 0:
prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]
constraints = None
if "constraints" in sample:
constraints = sample["constraints"]
gen_timer.start()
hypos = task.inference_step(
generator,
models[0],
sample,
prefix_tokens=prefix_tokens,
constraints=constraints,
)
num_generated_tokens = sum(len(h["unit"]) for h in hypos)
gen_timer.stop(num_generated_tokens)
for i, sample_id in enumerate(sample["id"].tolist()):
has_target = sample["target"] is not None
# Remove padding
if "src_tokens" in sample["net_input"]:
src_tokens = utils.strip_pad(
sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
).cpu()
else:
src_tokens = None
target_tokens = None
if has_target:
target_tokens = (
utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).cpu()
)
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
sample_id
)
target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
sample_id
)
else:
if src_dict is not None:
src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
else:
src_str = ""
if has_target:
target_str = " ".join(map(str, target_tokens.numpy().tolist()))
src_str = decode_fn(src_str)
if not cfg.common_eval.quiet:
if src_dict is not None:
print("S-{}\t{}".format(sample_id, src_str), file=output_file)
if has_target:
print("T-{}\t{}".format(sample_id, target_str), file=output_file)
# Process top predictions
j = 0
hypo = hypos[i]
hypo_tokens = hypo["unit"].int().cpu()
hypo_str = " ".join(map(str, hypo_tokens.numpy().tolist()))
alignment = None
detok_hypo_str = hypo_str
# add duration prediction
hypo_duration = " ".join(map(str, hypo["duration"].int().cpu().numpy().tolist()))
hypo_fa_src_str = src_dict.string(hypo["fa_src"].cpu().numpy(), cfg.common_eval.post_process)
# hypo_fa_src_str = " ".join(map(str, hypo["fa_src"].int().cpu().numpy() - 4))
if not cfg.common_eval.quiet:
# score = hypo["score"] / math.log(2) # convert to base 2
score = 0.00
# original hypothesis (after tokenization and BPE)
# print(
# "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
# file=output_file,
# )
# detokenized hypothesis
print(
"D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
file=output_file,
)
# duration prediction
print(
"L-{}\t{}\t{}".format(sample_id, score, hypo_duration),
file=output_file,
)
# force-aligned upsampled src-tokens
print(
"U-{}\t{}\t{}".format(sample_id, score, hypo_fa_src_str),
file=output_file,
)
# print(
# "P-{}\t{}".format(
# sample_id,
# " ".join(
# map(
# lambda x: "{:.4f}".format(x),
# # convert from base e to base 2
# hypo["positional_scores"]
# .div_(math.log(2))
# .tolist(),
# )
# ),
# ),
# file=output_file,
# )
if cfg.generation.print_alignment == "hard":
print(
"A-{}\t{}".format(
sample_id,
" ".join(
[
"{}-{}".format(src_idx, tgt_idx)
for src_idx, tgt_idx in alignment
]
),
),
file=output_file,
)
if cfg.generation.print_alignment == "soft":
print(
"A-{}\t{}".format(
sample_id,
" ".join(
[",".join(src_probs) for src_probs in alignment]
),
),
file=output_file,
)
# Score only the top hypothesis
if has_target and j == 0:
if hasattr(scorer, "add_string"):
scorer.add_string(target_str, detok_hypo_str)
else:
scorer.add(target_tokens, hypo_tokens)
wps_meter.update(num_generated_tokens)
progress.log({"wps": round(wps_meter.avg)})
num_sentences += (
sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
)
logger.info("NOTE: hypothesis and token scores are output in base 2")
logger.info(
"Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
num_sentences,
gen_timer.n,
gen_timer.sum,
num_sentences / gen_timer.sum,
1.0 / gen_timer.avg,
)
)
if has_target:
if cfg.bpe and not cfg.generation.sacrebleu:
if cfg.common_eval.post_process:
logger.warning(
"BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
)
else:
logger.warning(
"If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization"
)
# use print to be consistent with other main outputs: S-, H-, T-, D- and so on
print(
"Generate {} with beam={}: {}".format(
cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
),
file=output_file,
)
return scorer
def cli_main():
parser = options.get_generation_parser()
# TODO: replace this workaround with refactoring of `AudioPretraining`
parser.add_argument(
"--arch",
"-a",
metavar="ARCH",
default="wav2vec2",
help="Model architecture. For constructing tasks that rely on "
"model args (e.g. `AudioPretraining`)",
)
args = options.parse_args_and_arch(parser)
main(args)
if __name__ == "__main__":
cli_main()
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
"""
Modified from: https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/examples/speech_recognition/new/infer.py
1. add "utils.import_user_module(cfg.common)" so that usr-dir can be loaded
"""
import ast
import hashlib
import logging
import os
import shutil
import sys
from dataclasses import dataclass, field, is_dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import editdistance
import torch
import torch.distributed as dist
import examples
from examples.speech_recognition.new.decoders.decoder_config import (
DecoderConfig,
FlashlightDecoderConfig,
)
from examples.speech_recognition.new.decoders.decoder import Decoder
from fairseq import checkpoint_utils, distributed_utils, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.dataclass.configs import (
CheckpointConfig,
CommonConfig,
CommonEvalConfig,
DatasetConfig,
DistributedTrainingConfig,
FairseqDataclass,
)
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq.logging.progress_bar import BaseProgressBar
from fairseq.models.fairseq_model import FairseqModel
from omegaconf import OmegaConf
import hydra
from hydra.core.config_store import ConfigStore
logging.root.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
config_path = Path(examples.speech_recognition.new.__path__[0]).resolve() / "conf"
@dataclass
class DecodingConfig(DecoderConfig, FlashlightDecoderConfig):
unique_wer_file: bool = field(
default=False,
metadata={"help": "If set, use a unique file for storing WER"},
)
results_path: Optional[str] = field(
default=None,
metadata={
"help": "If set, write hypothesis and reference sentences into this directory"
},
)
@dataclass
class InferConfig(FairseqDataclass):
task: Any = None
decoding: DecodingConfig = DecodingConfig()
common: CommonConfig = CommonConfig()
common_eval: CommonEvalConfig = CommonEvalConfig()
checkpoint: CheckpointConfig = CheckpointConfig()
distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
dataset: DatasetConfig = DatasetConfig()
is_ax: bool = field(
default=False,
metadata={
"help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume"
},
)
def reset_logging():
root = logging.getLogger()
for handler in root.handlers:
root.removeHandler(handler)
root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
root.addHandler(handler)
class InferenceProcessor:
cfg: InferConfig
def __init__(self, cfg: InferConfig) -> None:
self.cfg = cfg
self.task = tasks.setup_task(cfg.task)
models, saved_cfg = self.load_model_ensemble()
self.models = models
self.saved_cfg = saved_cfg
self.tgt_dict = self.task.target_dictionary
self.task.load_dataset(
self.cfg.dataset.gen_subset,
task_cfg=saved_cfg.task,
)
self.generator = Decoder(cfg.decoding, self.tgt_dict)
self.gen_timer = StopwatchMeter()
self.wps_meter = TimeMeter()
self.num_sentences = 0
self.total_errors = 0
self.total_length = 0
self.hypo_words_file = None
self.hypo_units_file = None
self.ref_words_file = None
self.ref_units_file = None
self.progress_bar = self.build_progress_bar()
def __enter__(self) -> "InferenceProcessor":
if self.cfg.decoding.results_path is not None:
self.hypo_words_file = self.get_res_file("hypo.word")
self.hypo_units_file = self.get_res_file("hypo.units")
self.ref_words_file = self.get_res_file("ref.word")
self.ref_units_file = self.get_res_file("ref.units")
return self
def __exit__(self, *exc) -> bool:
if self.cfg.decoding.results_path is not None:
self.hypo_words_file.close()
self.hypo_units_file.close()
self.ref_words_file.close()
self.ref_units_file.close()
return False
def __iter__(self) -> Any:
for sample in self.progress_bar:
if not self.cfg.common.cpu:
sample = utils.move_to_cuda(sample)
# Happens on the last batch.
if "net_input" not in sample:
continue
yield sample
def log(self, *args, **kwargs):
self.progress_bar.log(*args, **kwargs)
def print(self, *args, **kwargs):
self.progress_bar.print(*args, **kwargs)
def get_res_file(self, fname: str) -> None:
fname = os.path.join(self.cfg.decoding.results_path, fname)
if self.data_parallel_world_size > 1:
fname = f"{fname}.{self.data_parallel_rank}"
return open(fname, "w", buffering=1)
def merge_shards(self) -> None:
"""Merges all shard files into shard 0, then removes shard suffix."""
shard_id = self.data_parallel_rank
num_shards = self.data_parallel_world_size
if self.data_parallel_world_size > 1:
def merge_shards_with_root(fname: str) -> None:
fname = os.path.join(self.cfg.decoding.results_path, fname)
logger.info("Merging %s on shard %d", fname, shard_id)
base_fpath = Path(f"{fname}.0")
with open(base_fpath, "a") as out_file:
for s in range(1, num_shards):
shard_fpath = Path(f"{fname}.{s}")
with open(shard_fpath, "r") as in_file:
for line in in_file:
out_file.write(line)
shard_fpath.unlink()
shutil.move(f"{fname}.0", fname)
dist.barrier() # ensure all shards finished writing
if shard_id == (0 % num_shards):
merge_shards_with_root("hypo.word")
if shard_id == (1 % num_shards):
merge_shards_with_root("hypo.units")
if shard_id == (2 % num_shards):
merge_shards_with_root("ref.word")
if shard_id == (3 % num_shards):
merge_shards_with_root("ref.units")
dist.barrier()
def optimize_model(self, model: FairseqModel) -> None:
model.make_generation_fast_()
if self.cfg.common.fp16:
model.half()
if not self.cfg.common.cpu:
model.cuda()
def load_model_ensemble(self) -> Tuple[List[FairseqModel], FairseqDataclass]:
arg_overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
models, saved_cfg = checkpoint_utils.load_model_ensemble(
utils.split_paths(self.cfg.common_eval.path, separator="\\"),
arg_overrides=arg_overrides,
task=self.task,
suffix=self.cfg.checkpoint.checkpoint_suffix,
strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=self.cfg.checkpoint.checkpoint_shard_count,
)
for model in models:
self.optimize_model(model)
return models, saved_cfg
def get_dataset_itr(self, disable_iterator_cache: bool = False) -> None:
return self.task.get_batch_iterator(
dataset=self.task.dataset(self.cfg.dataset.gen_subset),
max_tokens=self.cfg.dataset.max_tokens,
max_sentences=self.cfg.dataset.batch_size,
max_positions=(sys.maxsize, sys.maxsize),
ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
seed=self.cfg.common.seed,
num_shards=self.data_parallel_world_size,
shard_id=self.data_parallel_rank,
num_workers=self.cfg.dataset.num_workers,
data_buffer_size=self.cfg.dataset.data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
).next_epoch_itr(shuffle=False)
def build_progress_bar(
self,
epoch: Optional[int] = None,
prefix: Optional[str] = None,
default_log_format: str = "tqdm",
) -> BaseProgressBar:
return progress_bar.progress_bar(
iterator=self.get_dataset_itr(),
log_format=self.cfg.common.log_format,
log_interval=self.cfg.common.log_interval,
epoch=epoch,
prefix=prefix,
tensorboard_logdir=self.cfg.common.tensorboard_logdir,
default_log_format=default_log_format,
)
@property
def data_parallel_world_size(self):
if self.cfg.distributed_training.distributed_world_size == 1:
return 1
return distributed_utils.get_data_parallel_world_size()
@property
def data_parallel_rank(self):
if self.cfg.distributed_training.distributed_world_size == 1:
return 0
return distributed_utils.get_data_parallel_rank()
def process_sentence(
self,
sample: Dict[str, Any],
hypo: Dict[str, Any],
sid: int,
batch_id: int,
) -> Tuple[int, int]:
speaker = None # Speaker can't be parsed from dataset.
if "target_label" in sample:
toks = sample["target_label"]
else:
toks = sample["target"]
toks = toks[batch_id, :]
# Processes hypothesis.
hyp_pieces = self.tgt_dict.string(hypo["tokens"].int().cpu())
if "words" in hypo:
hyp_words = " ".join(hypo["words"])
else:
hyp_words = post_process(hyp_pieces, self.cfg.common_eval.post_process)
# Processes target.
target_tokens = utils.strip_pad(toks, self.tgt_dict.pad())
tgt_pieces = self.tgt_dict.string(target_tokens.int().cpu())
tgt_words = post_process(tgt_pieces, self.cfg.common_eval.post_process)
if self.cfg.decoding.results_path is not None:
print(f"{hyp_pieces} ({speaker}-{sid})", file=self.hypo_units_file)
print(f"{hyp_words} ({speaker}-{sid})", file=self.hypo_words_file)
print(f"{tgt_pieces} ({speaker}-{sid})", file=self.ref_units_file)
print(f"{tgt_words} ({speaker}-{sid})", file=self.ref_words_file)
if not self.cfg.common_eval.quiet:
logger.info(f"HYPO: {hyp_words}")
logger.info(f"REF: {tgt_words}")
logger.info("---------------------")
hyp_words, tgt_words = hyp_words.split(), tgt_words.split()
return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
def process_sample(self, sample: Dict[str, Any]) -> None:
self.gen_timer.start()
hypos = self.task.inference_step(
generator=self.generator,
models=self.models,
sample=sample,
)
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
self.gen_timer.stop(num_generated_tokens)
self.wps_meter.update(num_generated_tokens)
for batch_id, sample_id in enumerate(sample["id"].tolist()):
errs, length = self.process_sentence(
sample=sample,
sid=sample_id,
batch_id=batch_id,
hypo=hypos[batch_id][0],
)
self.total_errors += errs
self.total_length += length
self.log({"wps": round(self.wps_meter.avg)})
if "nsentences" in sample:
self.num_sentences += sample["nsentences"]
else:
self.num_sentences += sample["id"].numel()
def log_generation_time(self) -> None:
logger.info(
"Processed %d sentences (%d tokens) in %.1fs %.2f "
"sentences per second, %.2f tokens per second)",
self.num_sentences,
self.gen_timer.n,
self.gen_timer.sum,
self.num_sentences / (self.gen_timer.sum + 1e-6),
1.0 / (self.gen_timer.avg + 1e-6),
)
def parse_wer(wer_file: Path) -> float:
with open(wer_file, "r") as f:
return float(f.readline().strip().split(" ")[1])
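# For reference: main() below writes the WER file with a first line of the form "WER: <value>",
# so parse_wer would return, e.g., 12.34 for a file whose first line is "WER: 12.34".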
def get_wer_file(cfg: InferConfig) -> Path:
"""Hashes the decoding parameters to a unique file ID."""
base_path = "wer"
if cfg.decoding.results_path is not None:
base_path = os.path.join(cfg.decoding.results_path, base_path)
if cfg.decoding.unique_wer_file:
yaml_str = OmegaConf.to_yaml(cfg.decoding)
fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16)
return Path(f"{base_path}.{fid % 1000000}")
else:
return Path(base_path)
def main(cfg: InferConfig) -> float:
    """Entry point for the main processing logic.
    Args:
        cfg: The inference configuration to use.
    Returns:
        The final WER.
    """
utils.import_user_module(cfg.common)
yaml_str, wer_file = OmegaConf.to_yaml(cfg.decoding), get_wer_file(cfg)
# Validates the provided configuration.
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
cfg.dataset.max_tokens = 4000000
if not cfg.common.cpu and not torch.cuda.is_available():
raise ValueError("CUDA not found; set `cpu=True` to run without CUDA")
logger.info(cfg.common_eval.path)
with InferenceProcessor(cfg) as processor:
for sample in processor:
processor.process_sample(sample)
processor.log_generation_time()
if cfg.decoding.results_path is not None:
processor.merge_shards()
errs_t, leng_t = processor.total_errors, processor.total_length
if cfg.common.cpu:
logger.warning("Merging WER requires CUDA.")
elif processor.data_parallel_world_size > 1:
stats = torch.LongTensor([errs_t, leng_t]).cuda()
dist.all_reduce(stats, op=dist.ReduceOp.SUM)
errs_t, leng_t = stats[0].item(), stats[1].item()
wer = errs_t * 100.0 / leng_t
if distributed_utils.is_master(cfg.distributed_training):
with open(wer_file, "w") as f:
f.write(
(
f"WER: {wer}\n"
f"err / num_ref_words = {errs_t} / {leng_t}\n\n"
f"{yaml_str}"
)
)
return wer
@hydra.main(config_path=config_path, config_name="infer")
def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]:
container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
cfg = OmegaConf.create(container)
OmegaConf.set_struct(cfg, True)
if cfg.common.reset_logging:
reset_logging()
utils.import_user_module(cfg.common)
# logger.info("Config:\n%s", OmegaConf.to_yaml(cfg))
wer = float("inf")
try:
if cfg.common.profile:
with torch.cuda.profiler.profile():
with torch.autograd.profiler.emit_nvtx():
distributed_utils.call_main(cfg, main)
else:
distributed_utils.call_main(cfg, main)
wer = parse_wer(get_wer_file(cfg))
except BaseException as e: # pylint: disable=broad-except
if not cfg.common.suppress_crashes:
raise
else:
logger.error("Crashed! %s", str(e))
logger.info("Word error rate: %.4f", wer)
if cfg.is_ax:
return wer, None
return wer
def cli_main() -> None:
try:
from hydra._internal.utils import (
get_args,
) # pylint: disable=import-outside-toplevel
cfg_name = get_args().config_name or "infer"
except ImportError:
logger.warning("Failed to get config name from hydra args")
cfg_name = "infer"
cs = ConfigStore.instance()
cs.store(name=cfg_name, node=InferConfig)
for k in InferConfig.__dataclass_fields__:
if is_dataclass(InferConfig.__dataclass_fields__[k].type):
v = InferConfig.__dataclass_fields__[k].default
cs.store(name=k, node=v)
hydra_main() # pylint: disable=no-value-for-parameter
if __name__ == "__main__":
cli_main()