#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import gc
import os.path as osp
import warnings
from collections import deque, namedtuple
from typing import Any, Dict, List, Tuple
import numpy as np
import torch
from fairseq import tasks
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models.fairseq_model import FairseqModel
from fairseq.utils import apply_to_sample
from omegaconf import open_dict, OmegaConf
from .decoder_config import FlashlightDecoderConfig
from .base_decoder import BaseDecoder
try:
from flashlight.lib.text.decoder import (
LM,
CriterionType,
DecodeResult,
KenLM,
LexiconDecoder,
LexiconDecoderOptions,
LexiconFreeDecoder,
LexiconFreeDecoderOptions,
LMState,
SmearingMode,
Trie,
)
from flashlight.lib.text.dictionary import create_word_dict, load_words
except ImportError:
warnings.warn(
"flashlight python bindings are required to use this functionality. "
"Please install from "
"https://github.com/facebookresearch/flashlight/tree/master/bindings/python"
)
LM = object
LMState = object
class KenLMDecoder(BaseDecoder):
def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None:
super().__init__(tgt_dict)
self.nbest = cfg.nbest
self.unitlm = cfg.unitlm
if cfg.lexicon:
self.lexicon = load_words(cfg.lexicon)
self.word_dict = create_word_dict(self.lexicon)
self.unk_word = self.word_dict.get_index("<unk>")
self.lm = KenLM(cfg.lmpath, self.word_dict)
self.trie = Trie(self.vocab_size, self.silence)
start_state = self.lm.start(False)
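            # Pre-score each lexicon word with the LM from the start state and insert every spelling into the trie.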
for word, spellings in self.lexicon.items():
word_idx = self.word_dict.get_index(word)
_, score = self.lm.score(start_state, word_idx)
for spelling in spellings:
spelling_idxs = [tgt_dict.index(token) for token in spelling]
assert (
tgt_dict.unk() not in spelling_idxs
), f"{word} {spelling} {spelling_idxs}"
self.trie.insert(spelling_idxs, word_idx, score)
self.trie.smear(SmearingMode.MAX)
self.decoder_opts = LexiconDecoderOptions(
beam_size=cfg.beam,
beam_size_token=cfg.beamsizetoken or len(tgt_dict),
beam_threshold=cfg.beamthreshold,
lm_weight=cfg.lmweight,
word_score=cfg.wordscore,
unk_score=cfg.unkweight,
sil_score=cfg.silweight,
log_add=False,
criterion_type=CriterionType.CTC,
)
self.decoder = LexiconDecoder(
self.decoder_opts,
self.trie,
self.lm,
self.silence,
self.blank,
self.unk_word,
[],
self.unitlm,
)
else:
assert self.unitlm, "Lexicon-free decoding requires unit LM"
d = {w: [[w]] for w in tgt_dict.symbols}
self.word_dict = create_word_dict(d)
self.lm = KenLM(cfg.lmpath, self.word_dict)
self.decoder_opts = LexiconFreeDecoderOptions(
beam_size=cfg.beam,
beam_size_token=cfg.beamsizetoken or len(tgt_dict),
beam_threshold=cfg.beamthreshold,
lm_weight=cfg.lmweight,
sil_score=cfg.silweight,
log_add=False,
criterion_type=CriterionType.CTC,
)
self.decoder = LexiconFreeDecoder(
self.decoder_opts, self.lm, self.silence, self.blank, []
)
def get_timesteps(self, token_idxs: List[int]) -> List[int]:
"""Returns frame numbers corresponding to every non-blank token.
Parameters
----------
token_idxs : List[int]
IDs of decoded tokens.
Returns
-------
List[int]
Frame numbers corresponding to every non-blank token.
"""
timesteps = []
for i, token_idx in enumerate(token_idxs):
if token_idx == self.blank:
continue
if i == 0 or token_idx != token_idxs[i-1]:
timesteps.append(i)
return timesteps
def decode(
self,
emissions: torch.FloatTensor,
) -> List[List[Dict[str, torch.LongTensor]]]:
B, T, N = emissions.size()
hypos = []
for b in range(B):
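            # Emissions are float32, so advance the raw data pointer by b rows (4 bytes per element).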
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
results = self.decoder.decode(emissions_ptr, T, N)
nbest_results = results[: self.nbest]
hypos.append(
[
{
"tokens": self.get_tokens(result.tokens),
"score": result.score,
"timesteps": self.get_timesteps(result.tokens),
"words": [
self.word_dict.get_entry(x) for x in result.words if x >= 0
],
}
for result in nbest_results
]
)
return hypos
FairseqLMState = namedtuple(
"FairseqLMState",
[
"prefix",
"incremental_state",
"probs",
],
)
class FairseqLM(LM):
def __init__(self, dictionary: Dictionary, model: FairseqModel) -> None:
super().__init__()
self.dictionary = dictionary
self.model = model
self.unk = self.dictionary.unk()
self.save_incremental = False # this currently does not work properly
self.max_cache = 20_000
if torch.cuda.is_available():
model.cuda()
model.eval()
model.make_generation_fast_()
self.states = {}
self.stateq = deque()
def start(self, start_with_nothing: bool) -> LMState:
state = LMState()
prefix = torch.LongTensor([[self.dictionary.eos()]])
incremental_state = {} if self.save_incremental else None
with torch.no_grad():
res = self.model(prefix.cuda(), incremental_state=incremental_state)
probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)
if incremental_state is not None:
incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)
self.states[state] = FairseqLMState(
prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
)
self.stateq.append(state)
return state
def score(
self,
state: LMState,
token_index: int,
no_cache: bool = False,
    ) -> Tuple[LMState, float]:
"""
Evaluate language model based on the current lm state and new word
Parameters:
-----------
state: current lm state
token_index: index of the word
(can be lexicon index then you should store inside LM the
mapping between indices of lexicon and lm, or lm index of a word)
Returns:
--------
(LMState, float): pair of (new state, score for the current word)
"""
curr_state = self.states[state]
def trim_cache(targ_size: int) -> None:
while len(self.stateq) > targ_size:
rem_k = self.stateq.popleft()
rem_st = self.states[rem_k]
rem_st = FairseqLMState(rem_st.prefix, None, None)
self.states[rem_k] = rem_st
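        # probs is None when the state was created lazily as a child or was evicted by trim_cache; recompute on demand.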
if curr_state.probs is None:
new_incremental_state = (
curr_state.incremental_state.copy()
if curr_state.incremental_state is not None
else None
)
with torch.no_grad():
if new_incremental_state is not None:
new_incremental_state = apply_to_sample(
lambda x: x.cuda(), new_incremental_state
)
elif self.save_incremental:
new_incremental_state = {}
res = self.model(
torch.from_numpy(curr_state.prefix).cuda(),
incremental_state=new_incremental_state,
)
probs = self.model.get_normalized_probs(
res, log_probs=True, sample=None
)
if new_incremental_state is not None:
new_incremental_state = apply_to_sample(
lambda x: x.cpu(), new_incremental_state
)
curr_state = FairseqLMState(
curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
)
if not no_cache:
self.states[state] = curr_state
self.stateq.append(state)
score = curr_state.probs[token_index].item()
trim_cache(self.max_cache)
outstate = state.child(token_index)
if outstate not in self.states and not no_cache:
prefix = np.concatenate(
[curr_state.prefix, torch.LongTensor([[token_index]])], -1
)
incr_state = curr_state.incremental_state
self.states[outstate] = FairseqLMState(prefix, incr_state, None)
if token_index == self.unk:
score = float("-inf")
return outstate, score
    def finish(self, state: LMState) -> Tuple[LMState, float]:
"""
Evaluate eos for language model based on the current lm state
Returns:
--------
(LMState, float): pair of (new state, score for the current word)
"""
return self.score(state, self.dictionary.eos())
def empty_cache(self) -> None:
self.states = {}
self.stateq = deque()
gc.collect()
class FairseqLMDecoder(BaseDecoder):
def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None:
super().__init__(tgt_dict)
self.nbest = cfg.nbest
self.unitlm = cfg.unitlm
self.lexicon = load_words(cfg.lexicon) if cfg.lexicon else None
self.idx_to_wrd = {}
checkpoint = torch.load(cfg.lmpath, map_location="cpu")
if "cfg" in checkpoint and checkpoint["cfg"] is not None:
lm_args = checkpoint["cfg"]
else:
lm_args = convert_namespace_to_omegaconf(checkpoint["args"])
if not OmegaConf.is_dict(lm_args):
lm_args = OmegaConf.create(lm_args)
with open_dict(lm_args.task):
lm_args.task.data = osp.dirname(cfg.lmpath)
task = tasks.setup_task(lm_args.task)
model = task.build_model(lm_args.model)
model.load_state_dict(checkpoint["model"], strict=False)
self.trie = Trie(self.vocab_size, self.silence)
self.word_dict = task.dictionary
self.unk_word = self.word_dict.unk()
self.lm = FairseqLM(self.word_dict, model)
if self.lexicon:
start_state = self.lm.start(False)
for i, (word, spellings) in enumerate(self.lexicon.items()):
if self.unitlm:
word_idx = i
self.idx_to_wrd[i] = word
score = 0
else:
word_idx = self.word_dict.index(word)
_, score = self.lm.score(start_state, word_idx, no_cache=True)
for spelling in spellings:
spelling_idxs = [tgt_dict.index(token) for token in spelling]
assert (
tgt_dict.unk() not in spelling_idxs
), f"{spelling} {spelling_idxs}"
self.trie.insert(spelling_idxs, word_idx, score)
self.trie.smear(SmearingMode.MAX)
self.decoder_opts = LexiconDecoderOptions(
beam_size=cfg.beam,
beam_size_token=cfg.beamsizetoken or len(tgt_dict),
beam_threshold=cfg.beamthreshold,
lm_weight=cfg.lmweight,
word_score=cfg.wordscore,
unk_score=cfg.unkweight,
sil_score=cfg.silweight,
log_add=False,
criterion_type=CriterionType.CTC,
)
self.decoder = LexiconDecoder(
self.decoder_opts,
self.trie,
self.lm,
self.silence,
self.blank,
self.unk_word,
[],
self.unitlm,
)
else:
assert self.unitlm, "Lexicon-free decoding requires unit LM"
d = {w: [[w]] for w in tgt_dict.symbols}
self.word_dict = create_word_dict(d)
self.lm = KenLM(cfg.lmpath, self.word_dict)
self.decoder_opts = LexiconFreeDecoderOptions(
beam_size=cfg.beam,
beam_size_token=cfg.beamsizetoken or len(tgt_dict),
beam_threshold=cfg.beamthreshold,
lm_weight=cfg.lmweight,
sil_score=cfg.silweight,
log_add=False,
criterion_type=CriterionType.CTC,
)
self.decoder = LexiconFreeDecoder(
self.decoder_opts, self.lm, self.silence, self.blank, []
)
def decode(
self,
emissions: torch.FloatTensor,
) -> List[List[Dict[str, torch.LongTensor]]]:
B, T, N = emissions.size()
hypos = []
def make_hypo(result: DecodeResult) -> Dict[str, Any]:
hypo = {
"tokens": self.get_tokens(result.tokens),
"score": result.score,
}
if self.lexicon:
hypo["words"] = [
self.idx_to_wrd[x] if self.unitlm else self.word_dict[x]
for x in result.words
if x >= 0
]
return hypo
for b in range(B):
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
results = self.decoder.decode(emissions_ptr, T, N)
nbest_results = results[: self.nbest]
hypos.append([make_hypo(result) for result in nbest_results])
self.lm.empty_cache()
return hypos
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from typing import List, Dict
from .base_decoder import BaseDecoder
class ViterbiDecoder(BaseDecoder):
def decode(
self,
emissions: torch.FloatTensor,
) -> List[List[Dict[str, torch.LongTensor]]]:
def get_pred(e):
toks = e.argmax(dim=-1).unique_consecutive()
return toks[toks != self.blank]
return [[{"tokens": get_pred(x), "score": 0}] for x in emissions]
#!/usr/bin/env python -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ast
import hashlib
import logging
import os
import shutil
import sys
from dataclasses import dataclass, field, is_dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import editdistance
import torch
import torch.distributed as dist
from examples.speech_recognition.new.decoders.decoder_config import (
DecoderConfig,
FlashlightDecoderConfig,
)
from examples.speech_recognition.new.decoders.decoder import Decoder
from fairseq import checkpoint_utils, distributed_utils, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.dataclass.configs import (
CheckpointConfig,
CommonConfig,
CommonEvalConfig,
DatasetConfig,
DistributedTrainingConfig,
FairseqDataclass,
)
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq.logging.progress_bar import BaseProgressBar
from fairseq.models.fairseq_model import FairseqModel
from omegaconf import OmegaConf
import hydra
from hydra.core.config_store import ConfigStore
logging.root.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
config_path = Path(__file__).resolve().parent / "conf"
@dataclass
class DecodingConfig(DecoderConfig, FlashlightDecoderConfig):
unique_wer_file: bool = field(
default=False,
metadata={"help": "If set, use a unique file for storing WER"},
)
results_path: Optional[str] = field(
default=None,
metadata={
"help": "If set, write hypothesis and reference sentences into this directory"
},
)
@dataclass
class InferConfig(FairseqDataclass):
task: Any = None
decoding: DecodingConfig = DecodingConfig()
common: CommonConfig = CommonConfig()
common_eval: CommonEvalConfig = CommonEvalConfig()
checkpoint: CheckpointConfig = CheckpointConfig()
distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
dataset: DatasetConfig = DatasetConfig()
is_ax: bool = field(
default=False,
metadata={
"help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume"
},
)
def reset_logging():
root = logging.getLogger()
for handler in root.handlers:
root.removeHandler(handler)
root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
root.addHandler(handler)
class InferenceProcessor:
cfg: InferConfig
def __init__(self, cfg: InferConfig) -> None:
self.cfg = cfg
self.task = tasks.setup_task(cfg.task)
self.tgt_dict = self.task.target_dictionary
models, saved_cfg = self.load_model_ensemble()
self.models = models
self.saved_cfg = saved_cfg
self.task.load_dataset(
self.cfg.dataset.gen_subset,
task_cfg=saved_cfg.task,
)
self.generator = Decoder(cfg.decoding, self.tgt_dict)
self.gen_timer = StopwatchMeter()
self.wps_meter = TimeMeter()
self.num_sentences = 0
self.total_errors = 0
self.total_length = 0
self.hypo_words_file = None
self.hypo_units_file = None
self.ref_words_file = None
self.ref_units_file = None
self.progress_bar = self.build_progress_bar()
def __enter__(self) -> "InferenceProcessor":
if self.cfg.decoding.results_path is not None:
self.hypo_words_file = self.get_res_file("hypo.word")
self.hypo_units_file = self.get_res_file("hypo.units")
self.ref_words_file = self.get_res_file("ref.word")
self.ref_units_file = self.get_res_file("ref.units")
return self
def __exit__(self, *exc) -> bool:
if self.cfg.decoding.results_path is not None:
self.hypo_words_file.close()
self.hypo_units_file.close()
self.ref_words_file.close()
self.ref_units_file.close()
return False
def __iter__(self) -> Any:
for sample in self.progress_bar:
if not self.cfg.common.cpu:
sample = utils.move_to_cuda(sample)
# Happens on the last batch.
if "net_input" not in sample:
continue
yield sample
def log(self, *args, **kwargs):
self.progress_bar.log(*args, **kwargs)
def print(self, *args, **kwargs):
self.progress_bar.print(*args, **kwargs)
    def get_res_file(self, fname: str) -> Any:
fname = os.path.join(self.cfg.decoding.results_path, fname)
if self.data_parallel_world_size > 1:
fname = f"{fname}.{self.data_parallel_rank}"
return open(fname, "w", buffering=1)
def merge_shards(self) -> None:
"""Merges all shard files into shard 0, then removes shard suffix."""
shard_id = self.data_parallel_rank
num_shards = self.data_parallel_world_size
if self.data_parallel_world_size > 1:
def merge_shards_with_root(fname: str) -> None:
fname = os.path.join(self.cfg.decoding.results_path, fname)
logger.info("Merging %s on shard %d", fname, shard_id)
base_fpath = Path(f"{fname}.0")
with open(base_fpath, "a") as out_file:
for s in range(1, num_shards):
shard_fpath = Path(f"{fname}.{s}")
with open(shard_fpath, "r") as in_file:
for line in in_file:
out_file.write(line)
shard_fpath.unlink()
shutil.move(f"{fname}.0", fname)
dist.barrier() # ensure all shards finished writing
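            # Round-robin the four merge jobs across ranks so they can run in parallel.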
if shard_id == (0 % num_shards):
merge_shards_with_root("hypo.word")
if shard_id == (1 % num_shards):
merge_shards_with_root("hypo.units")
if shard_id == (2 % num_shards):
merge_shards_with_root("ref.word")
if shard_id == (3 % num_shards):
merge_shards_with_root("ref.units")
dist.barrier()
def optimize_model(self, model: FairseqModel) -> None:
model.make_generation_fast_()
if self.cfg.common.fp16:
model.half()
if not self.cfg.common.cpu:
model.cuda()
def load_model_ensemble(self) -> Tuple[List[FairseqModel], FairseqDataclass]:
arg_overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
models, saved_cfg = checkpoint_utils.load_model_ensemble(
utils.split_paths(self.cfg.common_eval.path, separator="\\"),
arg_overrides=arg_overrides,
task=self.task,
suffix=self.cfg.checkpoint.checkpoint_suffix,
strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=self.cfg.checkpoint.checkpoint_shard_count,
)
for model in models:
self.optimize_model(model)
return models, saved_cfg
    def get_dataset_itr(self, disable_iterator_cache: bool = False) -> Any:
return self.task.get_batch_iterator(
dataset=self.task.dataset(self.cfg.dataset.gen_subset),
max_tokens=self.cfg.dataset.max_tokens,
max_sentences=self.cfg.dataset.batch_size,
max_positions=(sys.maxsize, sys.maxsize),
ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
seed=self.cfg.common.seed,
num_shards=self.data_parallel_world_size,
shard_id=self.data_parallel_rank,
num_workers=self.cfg.dataset.num_workers,
data_buffer_size=self.cfg.dataset.data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
).next_epoch_itr(shuffle=False)
def build_progress_bar(
self,
epoch: Optional[int] = None,
prefix: Optional[str] = None,
default_log_format: str = "tqdm",
) -> BaseProgressBar:
return progress_bar.progress_bar(
iterator=self.get_dataset_itr(),
log_format=self.cfg.common.log_format,
log_interval=self.cfg.common.log_interval,
epoch=epoch,
prefix=prefix,
tensorboard_logdir=self.cfg.common.tensorboard_logdir,
default_log_format=default_log_format,
)
@property
def data_parallel_world_size(self):
if self.cfg.distributed_training.distributed_world_size == 1:
return 1
return distributed_utils.get_data_parallel_world_size()
@property
def data_parallel_rank(self):
if self.cfg.distributed_training.distributed_world_size == 1:
return 0
return distributed_utils.get_data_parallel_rank()
def process_sentence(
self,
sample: Dict[str, Any],
hypo: Dict[str, Any],
sid: int,
batch_id: int,
) -> Tuple[int, int]:
speaker = None # Speaker can't be parsed from dataset.
if "target_label" in sample:
toks = sample["target_label"]
else:
toks = sample["target"]
toks = toks[batch_id, :]
# Processes hypothesis.
hyp_pieces = self.tgt_dict.string(hypo["tokens"].int().cpu())
if "words" in hypo:
hyp_words = " ".join(hypo["words"])
else:
hyp_words = post_process(hyp_pieces, self.cfg.common_eval.post_process)
# Processes target.
target_tokens = utils.strip_pad(toks, self.tgt_dict.pad())
tgt_pieces = self.tgt_dict.string(target_tokens.int().cpu())
tgt_words = post_process(tgt_pieces, self.cfg.common_eval.post_process)
if self.cfg.decoding.results_path is not None:
print(f"{hyp_pieces} ({speaker}-{sid})", file=self.hypo_units_file)
print(f"{hyp_words} ({speaker}-{sid})", file=self.hypo_words_file)
print(f"{tgt_pieces} ({speaker}-{sid})", file=self.ref_units_file)
print(f"{tgt_words} ({speaker}-{sid})", file=self.ref_words_file)
if not self.cfg.common_eval.quiet:
logger.info(f"HYPO: {hyp_words}")
logger.info(f"REF: {tgt_words}")
logger.info("---------------------")
hyp_words, tgt_words = hyp_words.split(), tgt_words.split()
return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
def process_sample(self, sample: Dict[str, Any]) -> None:
self.gen_timer.start()
hypos = self.task.inference_step(
generator=self.generator,
models=self.models,
sample=sample,
)
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
self.gen_timer.stop(num_generated_tokens)
self.wps_meter.update(num_generated_tokens)
for batch_id, sample_id in enumerate(sample["id"].tolist()):
errs, length = self.process_sentence(
sample=sample,
sid=sample_id,
batch_id=batch_id,
hypo=hypos[batch_id][0],
)
self.total_errors += errs
self.total_length += length
self.log({"wps": round(self.wps_meter.avg)})
if "nsentences" in sample:
self.num_sentences += sample["nsentences"]
else:
self.num_sentences += sample["id"].numel()
def log_generation_time(self) -> None:
logger.info(
"Processed %d sentences (%d tokens) in %.1fs %.2f "
"sentences per second, %.2f tokens per second)",
self.num_sentences,
self.gen_timer.n,
self.gen_timer.sum,
self.num_sentences / self.gen_timer.sum,
1.0 / self.gen_timer.avg,
)
def parse_wer(wer_file: Path) -> float:
with open(wer_file, "r") as f:
return float(f.readline().strip().split(" ")[1])
def get_wer_file(cfg: InferConfig) -> Path:
"""Hashes the decoding parameters to a unique file ID."""
base_path = "wer"
if cfg.decoding.results_path is not None:
base_path = os.path.join(cfg.decoding.results_path, base_path)
if cfg.decoding.unique_wer_file:
yaml_str = OmegaConf.to_yaml(cfg.decoding)
fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16)
return Path(f"{base_path}.{fid % 1000000}")
else:
return Path(base_path)
def main(cfg: InferConfig) -> float:
"""Entry point for main processing logic.
Args:
cfg: The inferance configuration to use.
wer: Optional shared memory pointer for returning the WER. If not None,
the final WER value will be written here instead of being returned.
Returns:
The final WER if `wer` is None, otherwise None.
"""
yaml_str, wer_file = OmegaConf.to_yaml(cfg.decoding), get_wer_file(cfg)
# Validates the provided configuration.
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
cfg.dataset.max_tokens = 4000000
if not cfg.common.cpu and not torch.cuda.is_available():
raise ValueError("CUDA not found; set `cpu=True` to run without CUDA")
with InferenceProcessor(cfg) as processor:
for sample in processor:
processor.process_sample(sample)
processor.log_generation_time()
if cfg.decoding.results_path is not None:
processor.merge_shards()
errs_t, leng_t = processor.total_errors, processor.total_length
if cfg.common.cpu:
logger.warning("Merging WER requires CUDA.")
elif processor.data_parallel_world_size > 1:
stats = torch.LongTensor([errs_t, leng_t]).cuda()
dist.all_reduce(stats, op=dist.ReduceOp.SUM)
errs_t, leng_t = stats[0].item(), stats[1].item()
wer = errs_t * 100.0 / leng_t
if distributed_utils.is_master(cfg.distributed_training):
with open(wer_file, "w") as f:
f.write(
(
f"WER: {wer}\n"
f"err / num_ref_words = {errs_t} / {leng_t}\n\n"
f"{yaml_str}"
)
)
return wer
@hydra.main(config_path=config_path, config_name="infer")
def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]:
container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
cfg = OmegaConf.create(container)
OmegaConf.set_struct(cfg, True)
if cfg.common.reset_logging:
reset_logging()
# logger.info("Config:\n%s", OmegaConf.to_yaml(cfg))
wer = float("inf")
try:
if cfg.common.profile:
with torch.cuda.profiler.profile():
with torch.autograd.profiler.emit_nvtx():
distributed_utils.call_main(cfg, main)
else:
distributed_utils.call_main(cfg, main)
wer = parse_wer(get_wer_file(cfg))
except BaseException as e: # pylint: disable=broad-except
if not cfg.common.suppress_crashes:
raise
else:
logger.error("Crashed! %s", str(e))
logger.info("Word error rate: %.4f", wer)
if cfg.is_ax:
return wer, None
return wer
def cli_main() -> None:
try:
from hydra._internal.utils import (
get_args,
) # pylint: disable=import-outside-toplevel
cfg_name = get_args().config_name or "infer"
except ImportError:
logger.warning("Failed to get config name from hydra args")
cfg_name = "infer"
cs = ConfigStore.instance()
cs.store(name=cfg_name, node=InferConfig)
for k in InferConfig.__dataclass_fields__:
if is_dataclass(InferConfig.__dataclass_fields__[k].type):
v = InferConfig.__dataclass_fields__[k].default
cs.store(name=k, node=v)
hydra_main() # pylint: disable=no-value-for-parameter
if __name__ == "__main__":
cli_main()
import importlib
import os
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
task_name = file[: file.find(".py")]
importlib.import_module("examples.speech_recognition.tasks." + task_name)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import re
import sys
import torch
from examples.speech_recognition.data import AsrDataset
from examples.speech_recognition.data.replabels import replabel_symbol
from fairseq.data import Dictionary
from fairseq.tasks import LegacyFairseqTask, register_task
def get_asr_dataset_from_json(data_json_path, tgt_dict):
"""
Parse data json and create dataset.
See scripts/asr_prep_json.py which pack json from raw files
Json example:
{
"utts": {
"4771-29403-0025": {
"input": {
"length_ms": 170,
"path": "/tmp/file1.flac"
},
"output": {
"text": "HELLO \n",
"token": "HE LLO",
"tokenid": "4815, 861"
}
},
"1564-142299-0096": {
...
}
}
"""
if not os.path.isfile(data_json_path):
raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
with open(data_json_path, "rb") as f:
data_samples = json.load(f)["utts"]
assert len(data_samples) != 0
sorted_samples = sorted(
data_samples.items(),
key=lambda sample: int(sample[1]["input"]["length_ms"]),
reverse=True,
)
aud_paths = [s[1]["input"]["path"] for s in sorted_samples]
ids = [s[0] for s in sorted_samples]
speakers = []
for s in sorted_samples:
m = re.search("(.+?)-(.+?)-(.+?)", s[0])
speakers.append(m.group(1) + "_" + m.group(2))
frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples]
tgt = [
[int(i) for i in s[1]["output"]["tokenid"].split(", ")]
for s in sorted_samples
]
# append eos
tgt = [[*t, tgt_dict.eos()] for t in tgt]
return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
@register_task("speech_recognition")
class SpeechRecognitionTask(LegacyFairseqTask):
"""
    Task for training speech recognition models.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument("data", help="path to data directory")
parser.add_argument(
"--silence-token", default="\u2581", help="token for silence (used by w2l)"
)
parser.add_argument(
"--max-source-positions",
default=sys.maxsize,
type=int,
metavar="N",
help="max number of frames in the source sequence",
)
parser.add_argument(
"--max-target-positions",
default=1024,
type=int,
metavar="N",
help="max number of tokens in the target sequence",
)
def __init__(self, args, tgt_dict):
super().__init__(args)
self.tgt_dict = tgt_dict
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task (e.g., load dictionaries)."""
dict_path = os.path.join(args.data, "dict.txt")
if not os.path.isfile(dict_path):
raise FileNotFoundError("Dict not found: {}".format(dict_path))
tgt_dict = Dictionary.load(dict_path)
if args.criterion == "ctc_loss":
tgt_dict.add_symbol("<ctc_blank>")
elif args.criterion == "asg_loss":
for i in range(1, args.max_replabel + 1):
tgt_dict.add_symbol(replabel_symbol(i))
print("| dictionary: {} types".format(len(tgt_dict)))
return cls(args, tgt_dict)
def load_dataset(self, split, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
data_json_path = os.path.join(self.args.data, "{}.json".format(split))
self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)
def build_generator(self, models, args, **unused):
w2l_decoder = getattr(args, "w2l_decoder", None)
if w2l_decoder == "viterbi":
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
return W2lViterbiDecoder(args, self.target_dictionary)
elif w2l_decoder == "kenlm":
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
return W2lKenLMDecoder(args, self.target_dictionary)
elif w2l_decoder == "fairseqlm":
from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
return W2lFairseqLMDecoder(args, self.target_dictionary)
else:
return super().build_generator(models, args)
@property
def target_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.tgt_dict
@property
def source_dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary` (if applicable
for this task)."""
return None
def max_positions(self):
"""Return the max speech and sentence length allowed by the task."""
return (self.args.max_source_positions, self.args.max_target_positions)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import re
from collections import deque
from enum import Enum
import numpy as np
"""
Utility modules for computation of Word Error Rate,
Alignments, as well as more granular metrics like
deletion, insersion and substitutions.
"""
class Code(Enum):
match = 1
substitution = 2
insertion = 3
deletion = 4
class Token(object):
def __init__(self, lbl="", st=np.nan, en=np.nan):
if np.isnan(st):
self.label, self.start, self.end = "", 0.0, 0.0
else:
self.label, self.start, self.end = lbl, st, en
class AlignmentResult(object):
def __init__(self, refs, hyps, codes, score):
self.refs = refs # std::deque<int>
self.hyps = hyps # std::deque<int>
self.codes = codes # std::deque<Code>
self.score = score # float
def coordinate_to_offset(row, col, ncols):
return int(row * ncols + col)
def offset_to_row(offset, ncols):
return int(offset / ncols)
def offset_to_col(offset, ncols):
return int(offset % ncols)
def trimWhitespace(str):
return re.sub(" +", " ", re.sub(" *$", "", re.sub("^ *", "", str)))
def str2toks(str):
pieces = trimWhitespace(str).split(" ")
toks = []
for p in pieces:
toks.append(Token(p, 0.0, 0.0))
return toks
class EditDistance(object):
def __init__(self, time_mediated):
self.time_mediated_ = time_mediated
self.scores_ = np.nan # Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>
self.backtraces_ = (
np.nan
) # Eigen::Matrix<size_t, Eigen::Dynamic, Eigen::Dynamic> backtraces_;
self.confusion_pairs_ = {}
def cost(self, ref, hyp, code):
if self.time_mediated_:
if code == Code.match:
return abs(ref.start - hyp.start) + abs(ref.end - hyp.end)
elif code == Code.insertion:
return hyp.end - hyp.start
elif code == Code.deletion:
return ref.end - ref.start
else: # substitution
return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + 0.1
else:
if code == Code.match:
return 0
elif code == Code.insertion or code == Code.deletion:
return 3
else: # substitution
return 4
def get_result(self, refs, hyps):
res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=np.nan)
num_rows, num_cols = self.scores_.shape
res.score = self.scores_[num_rows - 1, num_cols - 1]
curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols)
while curr_offset != 0:
curr_row = offset_to_row(curr_offset, num_cols)
curr_col = offset_to_col(curr_offset, num_cols)
prev_offset = self.backtraces_[curr_row, curr_col]
prev_row = offset_to_row(prev_offset, num_cols)
prev_col = offset_to_col(prev_offset, num_cols)
res.refs.appendleft(curr_row - 1) # Note: this was .push_front() in C++
res.hyps.appendleft(curr_col - 1)
if curr_row - 1 == prev_row and curr_col == prev_col:
res.codes.appendleft(Code.deletion)
elif curr_row == prev_row and curr_col - 1 == prev_col:
res.codes.appendleft(Code.insertion)
else:
# assert(curr_row - 1 == prev_row and curr_col - 1 == prev_col)
ref_str = refs[res.refs[0]].label
hyp_str = hyps[res.hyps[0]].label
if ref_str == hyp_str:
res.codes.appendleft(Code.match)
else:
res.codes.appendleft(Code.substitution)
confusion_pair = "%s -> %s" % (ref_str, hyp_str)
if confusion_pair not in self.confusion_pairs_:
self.confusion_pairs_[confusion_pair] = 1
else:
self.confusion_pairs_[confusion_pair] += 1
curr_offset = prev_offset
return res
def align(self, refs, hyps):
if len(refs) == 0 and len(hyps) == 0:
return np.nan
# NOTE: we're not resetting the values in these matrices because every value
# will be overridden in the loop below. If this assumption doesn't hold,
# be sure to set all entries in self.scores_ and self.backtraces_ to 0.
self.scores_ = np.zeros((len(refs) + 1, len(hyps) + 1))
self.backtraces_ = np.zeros((len(refs) + 1, len(hyps) + 1))
num_rows, num_cols = self.scores_.shape
for i in range(num_rows):
for j in range(num_cols):
if i == 0 and j == 0:
self.scores_[i, j] = 0.0
self.backtraces_[i, j] = 0
continue
if i == 0:
self.scores_[i, j] = self.scores_[i, j - 1] + self.cost(
None, hyps[j - 1], Code.insertion
)
self.backtraces_[i, j] = coordinate_to_offset(i, j - 1, num_cols)
continue
if j == 0:
self.scores_[i, j] = self.scores_[i - 1, j] + self.cost(
refs[i - 1], None, Code.deletion
)
self.backtraces_[i, j] = coordinate_to_offset(i - 1, j, num_cols)
continue
# Below here both i and j are greater than 0
ref = refs[i - 1]
hyp = hyps[j - 1]
best_score = self.scores_[i - 1, j - 1] + (
self.cost(ref, hyp, Code.match)
if (ref.label == hyp.label)
else self.cost(ref, hyp, Code.substitution)
)
prev_row = i - 1
prev_col = j - 1
ins = self.scores_[i, j - 1] + self.cost(None, hyp, Code.insertion)
if ins < best_score:
best_score = ins
prev_row = i
prev_col = j - 1
delt = self.scores_[i - 1, j] + self.cost(ref, None, Code.deletion)
if delt < best_score:
best_score = delt
prev_row = i - 1
prev_col = j
self.scores_[i, j] = best_score
self.backtraces_[i, j] = coordinate_to_offset(
prev_row, prev_col, num_cols
)
return self.get_result(refs, hyps)
class WERTransformer(object):
def __init__(self, hyp_str, ref_str, verbose=True):
self.ed_ = EditDistance(False)
self.id2oracle_errs_ = {}
self.utts_ = 0
self.words_ = 0
self.insertions_ = 0
self.deletions_ = 0
self.substitutions_ = 0
self.process(["dummy_str", hyp_str, ref_str])
if verbose:
print("'%s' vs '%s'" % (hyp_str, ref_str))
self.report_result()
def process(self, input): # std::vector<std::string>&& input
if len(input) < 3:
print(
"Input must be of the form <id> ... <hypo> <ref> , got ",
len(input),
" inputs:",
)
return None
# Align
# std::vector<Token> hyps;
# std::vector<Token> refs;
hyps = str2toks(input[-2])
refs = str2toks(input[-1])
alignment = self.ed_.align(refs, hyps)
if alignment is None:
print("Alignment is null")
return np.nan
# Tally errors
ins = 0
dels = 0
subs = 0
for code in alignment.codes:
if code == Code.substitution:
subs += 1
elif code == Code.insertion:
ins += 1
elif code == Code.deletion:
dels += 1
# Output
row = input
row.append(str(len(refs)))
row.append(str(ins))
row.append(str(dels))
row.append(str(subs))
# print(row)
# Accumulate
kIdIndex = 0
kNBestSep = "/"
pieces = input[kIdIndex].split(kNBestSep)
if len(pieces) == 0:
print(
"Error splitting ",
input[kIdIndex],
" on '",
kNBestSep,
"', got empty list",
)
return np.nan
id = pieces[0]
if id not in self.id2oracle_errs_:
self.utts_ += 1
self.words_ += len(refs)
self.insertions_ += ins
self.deletions_ += dels
self.substitutions_ += subs
self.id2oracle_errs_[id] = [ins, dels, subs]
else:
curr_err = ins + dels + subs
prev_err = np.sum(self.id2oracle_errs_[id])
if curr_err < prev_err:
self.id2oracle_errs_[id] = [ins, dels, subs]
return 0
def report_result(self):
# print("---------- Summary ---------------")
if self.words_ == 0:
print("No words counted")
return
# 1-best
best_wer = (
100.0
* (self.insertions_ + self.deletions_ + self.substitutions_)
/ self.words_
)
print(
"\tWER = %0.2f%% (%i utts, %i words, %0.2f%% ins, "
"%0.2f%% dels, %0.2f%% subs)"
% (
best_wer,
self.utts_,
self.words_,
100.0 * self.insertions_ / self.words_,
100.0 * self.deletions_ / self.words_,
100.0 * self.substitutions_ / self.words_,
)
)
def wer(self):
if self.words_ == 0:
wer = np.nan
else:
wer = (
100.0
* (self.insertions_ + self.deletions_ + self.substitutions_)
/ self.words_
)
return wer
def stats(self):
if self.words_ == 0:
stats = {}
else:
wer = (
100.0
* (self.insertions_ + self.deletions_ + self.substitutions_)
/ self.words_
)
stats = dict(
{
"wer": wer,
"utts": self.utts_,
"numwords": self.words_,
"ins": self.insertions_,
"dels": self.deletions_,
"subs": self.substitutions_,
"confusion_pairs": self.ed_.confusion_pairs_,
}
)
return stats
def calc_wer(hyp_str, ref_str):
t = WERTransformer(hyp_str, ref_str, verbose=0)
return t.wer()
def calc_wer_stats(hyp_str, ref_str):
t = WERTransformer(hyp_str, ref_str, verbose=0)
return t.stats()
def get_wer_alignment_codes(hyp_str, ref_str):
"""
INPUT: hypothesis string, reference string
OUTPUT: List of alignment codes (intermediate results from WER computation)
"""
t = WERTransformer(hyp_str, ref_str, verbose=0)
return t.ed_.align(str2toks(ref_str), str2toks(hyp_str)).codes
def merge_counts(x, y):
# Merge two hashes which have 'counts' as their values
# This can be used for example to merge confusion pair counts
# conf_pairs = merge_counts(conf_pairs, stats['confusion_pairs'])
for k, v in y.items():
if k not in x:
x[k] = 0
x[k] += v
return x
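
# Illustrative usage (a hedged sketch; the example strings below are made up
# and hypotheses/references are plain space-separated word strings):
if __name__ == "__main__":
    hyp, ref = "the cat sat", "the cat sat down"
    print(calc_wer(hyp, ref))  # WER in percent for this hypothesis/reference pair
    stats = calc_wer_stats(hyp, ref)
    # Confusion-pair counts from multiple pairs can be accumulated with merge_counts().
    conf_pairs = merge_counts({}, stats["confusion_pairs"])
    print(stats["numwords"], stats["ins"], stats["dels"], stats["subs"], conf_pairs)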
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Flashlight decoders.
"""
import gc
import itertools as it
import os.path as osp
from typing import List
import warnings
from collections import deque, namedtuple
import numpy as np
import torch
from examples.speech_recognition.data.replabels import unpack_replabels
from fairseq import tasks
from fairseq.utils import apply_to_sample
from omegaconf import open_dict
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
try:
from flashlight.lib.text.dictionary import create_word_dict, load_words
from flashlight.lib.sequence.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from flashlight.lib.text.decoder import (
CriterionType,
LexiconDecoderOptions,
KenLM,
LM,
LMState,
SmearingMode,
Trie,
LexiconDecoder,
)
except ImportError:
warnings.warn(
"flashlight python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/flashlight/tree/master/bindings/python"
)
LM = object
LMState = object
class W2lDecoder(object):
def __init__(self, args, tgt_dict):
self.tgt_dict = tgt_dict
self.vocab_size = len(tgt_dict)
self.nbest = args.nbest
# criterion-specific init
self.criterion_type = CriterionType.CTC
self.blank = (
tgt_dict.index("<ctc_blank>")
if "<ctc_blank>" in tgt_dict.indices
else tgt_dict.bos()
)
if "<sep>" in tgt_dict.indices:
self.silence = tgt_dict.index("<sep>")
elif "|" in tgt_dict.indices:
self.silence = tgt_dict.index("|")
else:
self.silence = tgt_dict.eos()
self.asg_transitions = None
def generate(self, models, sample, **unused):
"""Generate a batch of inferences."""
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
emissions = self.get_emissions(models, encoder_input)
return self.decode(emissions)
def get_emissions(self, models, encoder_input):
"""Run encoder and normalize emissions"""
model = models[0]
encoder_out = model(**encoder_input)
if hasattr(model, "get_logits"):
emissions = model.get_logits(encoder_out) # no need to normalize emissions
else:
emissions = model.get_normalized_probs(encoder_out, log_probs=True)
return emissions.transpose(0, 1).float().cpu().contiguous()
def get_tokens(self, idxs):
"""Normalize tokens by handling CTC blank, ASG replabels, etc."""
idxs = (g[0] for g in it.groupby(idxs))
idxs = filter(lambda x: x != self.blank, idxs)
return torch.LongTensor(list(idxs))
class W2lViterbiDecoder(W2lDecoder):
def __init__(self, args, tgt_dict):
super().__init__(args, tgt_dict)
def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
if self.asg_transitions is None:
transitions = torch.FloatTensor(N, N).zero_()
else:
transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
viterbi_path = torch.IntTensor(B, T)
workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
CpuViterbiPath.compute(
B,
T,
N,
get_data_ptr_as_bytes(emissions),
get_data_ptr_as_bytes(transitions),
get_data_ptr_as_bytes(viterbi_path),
get_data_ptr_as_bytes(workspace),
)
return [
[{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
for b in range(B)
]
class W2lKenLMDecoder(W2lDecoder):
def __init__(self, args, tgt_dict):
super().__init__(args, tgt_dict)
self.unit_lm = getattr(args, "unit_lm", False)
if args.lexicon:
self.lexicon = load_words(args.lexicon)
self.word_dict = create_word_dict(self.lexicon)
self.unk_word = self.word_dict.get_index("<unk>")
self.lm = KenLM(args.kenlm_model, self.word_dict)
self.trie = Trie(self.vocab_size, self.silence)
start_state = self.lm.start(False)
for i, (word, spellings) in enumerate(self.lexicon.items()):
word_idx = self.word_dict.get_index(word)
_, score = self.lm.score(start_state, word_idx)
for spelling in spellings:
spelling_idxs = [tgt_dict.index(token) for token in spelling]
assert (
tgt_dict.unk() not in spelling_idxs
), f"{spelling} {spelling_idxs}"
self.trie.insert(spelling_idxs, word_idx, score)
self.trie.smear(SmearingMode.MAX)
self.decoder_opts = LexiconDecoderOptions(
beam_size=args.beam,
beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
beam_threshold=args.beam_threshold,
lm_weight=args.lm_weight,
word_score=args.word_score,
unk_score=args.unk_weight,
sil_score=args.sil_weight,
log_add=False,
criterion_type=self.criterion_type,
)
if self.asg_transitions is None:
N = 768
# self.asg_transitions = torch.FloatTensor(N, N).zero_()
self.asg_transitions = []
self.decoder = LexiconDecoder(
self.decoder_opts,
self.trie,
self.lm,
self.silence,
self.blank,
self.unk_word,
self.asg_transitions,
self.unit_lm,
)
else:
assert args.unit_lm, "lexicon free decoding can only be done with a unit language model"
from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions
d = {w: [[w]] for w in tgt_dict.symbols}
self.word_dict = create_word_dict(d)
self.lm = KenLM(args.kenlm_model, self.word_dict)
self.decoder_opts = LexiconFreeDecoderOptions(
beam_size=args.beam,
beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
beam_threshold=args.beam_threshold,
lm_weight=args.lm_weight,
sil_score=args.sil_weight,
log_add=False,
criterion_type=self.criterion_type,
)
self.decoder = LexiconFreeDecoder(
self.decoder_opts, self.lm, self.silence, self.blank, []
)
def get_timesteps(self, token_idxs: List[int]) -> List[int]:
"""Returns frame numbers corresponding to every non-blank token.
Parameters
----------
token_idxs : List[int]
IDs of decoded tokens.
Returns
-------
List[int]
Frame numbers corresponding to every non-blank token.
"""
timesteps = []
for i, token_idx in enumerate(token_idxs):
if token_idx == self.blank:
continue
if i == 0 or token_idx != token_idxs[i-1]:
timesteps.append(i)
return timesteps
def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
for b in range(B):
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
results = self.decoder.decode(emissions_ptr, T, N)
nbest_results = results[: self.nbest]
hypos.append(
[
{
"tokens": self.get_tokens(result.tokens),
"score": result.score,
"timesteps": self.get_timesteps(result.tokens),
"words": [
self.word_dict.get_entry(x) for x in result.words if x >= 0
],
}
for result in nbest_results
]
)
return hypos
FairseqLMState = namedtuple("FairseqLMState", ["prefix", "incremental_state", "probs"])
class FairseqLM(LM):
def __init__(self, dictionary, model):
LM.__init__(self)
self.dictionary = dictionary
self.model = model
self.unk = self.dictionary.unk()
self.save_incremental = False # this currently does not work properly
self.max_cache = 20_000
model.cuda()
model.eval()
model.make_generation_fast_()
self.states = {}
self.stateq = deque()
def start(self, start_with_nothing):
state = LMState()
prefix = torch.LongTensor([[self.dictionary.eos()]])
incremental_state = {} if self.save_incremental else None
with torch.no_grad():
res = self.model(prefix.cuda(), incremental_state=incremental_state)
probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)
if incremental_state is not None:
incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)
self.states[state] = FairseqLMState(
prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
)
self.stateq.append(state)
return state
def score(self, state: LMState, token_index: int, no_cache: bool = False):
"""
Evaluate language model based on the current lm state and new word
Parameters:
-----------
state: current lm state
token_index: index of the word
(can be lexicon index then you should store inside LM the
mapping between indices of lexicon and lm, or lm index of a word)
Returns:
--------
(LMState, float): pair of (new state, score for the current word)
"""
curr_state = self.states[state]
def trim_cache(targ_size):
while len(self.stateq) > targ_size:
rem_k = self.stateq.popleft()
rem_st = self.states[rem_k]
rem_st = FairseqLMState(rem_st.prefix, None, None)
self.states[rem_k] = rem_st
if curr_state.probs is None:
new_incremental_state = (
curr_state.incremental_state.copy()
if curr_state.incremental_state is not None
else None
)
with torch.no_grad():
if new_incremental_state is not None:
new_incremental_state = apply_to_sample(
lambda x: x.cuda(), new_incremental_state
)
elif self.save_incremental:
new_incremental_state = {}
res = self.model(
torch.from_numpy(curr_state.prefix).cuda(),
incremental_state=new_incremental_state,
)
probs = self.model.get_normalized_probs(
res, log_probs=True, sample=None
)
if new_incremental_state is not None:
new_incremental_state = apply_to_sample(
lambda x: x.cpu(), new_incremental_state
)
curr_state = FairseqLMState(
curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
)
if not no_cache:
self.states[state] = curr_state
self.stateq.append(state)
score = curr_state.probs[token_index].item()
trim_cache(self.max_cache)
outstate = state.child(token_index)
if outstate not in self.states and not no_cache:
prefix = np.concatenate(
[curr_state.prefix, torch.LongTensor([[token_index]])], -1
)
incr_state = curr_state.incremental_state
self.states[outstate] = FairseqLMState(prefix, incr_state, None)
if token_index == self.unk:
score = float("-inf")
return outstate, score
def finish(self, state: LMState):
"""
Evaluate eos for language model based on the current lm state
Returns:
--------
(LMState, float): pair of (new state, score for the current word)
"""
return self.score(state, self.dictionary.eos())
def empty_cache(self):
self.states = {}
self.stateq = deque()
gc.collect()
class W2lFairseqLMDecoder(W2lDecoder):
def __init__(self, args, tgt_dict):
super().__init__(args, tgt_dict)
self.unit_lm = getattr(args, "unit_lm", False)
self.lexicon = load_words(args.lexicon) if args.lexicon else None
self.idx_to_wrd = {}
checkpoint = torch.load(args.kenlm_model, map_location="cpu")
if "cfg" in checkpoint and checkpoint["cfg"] is not None:
lm_args = checkpoint["cfg"]
else:
lm_args = convert_namespace_to_omegaconf(checkpoint["args"])
with open_dict(lm_args.task):
lm_args.task.data = osp.dirname(args.kenlm_model)
task = tasks.setup_task(lm_args.task)
model = task.build_model(lm_args.model)
model.load_state_dict(checkpoint["model"], strict=False)
self.trie = Trie(self.vocab_size, self.silence)
self.word_dict = task.dictionary
self.unk_word = self.word_dict.unk()
self.lm = FairseqLM(self.word_dict, model)
if self.lexicon:
start_state = self.lm.start(False)
for i, (word, spellings) in enumerate(self.lexicon.items()):
if self.unit_lm:
word_idx = i
self.idx_to_wrd[i] = word
score = 0
else:
word_idx = self.word_dict.index(word)
_, score = self.lm.score(start_state, word_idx, no_cache=True)
for spelling in spellings:
spelling_idxs = [tgt_dict.index(token) for token in spelling]
assert (
tgt_dict.unk() not in spelling_idxs
), f"{spelling} {spelling_idxs}"
self.trie.insert(spelling_idxs, word_idx, score)
self.trie.smear(SmearingMode.MAX)
self.decoder_opts = LexiconDecoderOptions(
beam_size=args.beam,
beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
beam_threshold=args.beam_threshold,
lm_weight=args.lm_weight,
word_score=args.word_score,
unk_score=args.unk_weight,
sil_score=args.sil_weight,
log_add=False,
criterion_type=self.criterion_type,
)
self.decoder = LexiconDecoder(
self.decoder_opts,
self.trie,
self.lm,
self.silence,
self.blank,
self.unk_word,
[],
self.unit_lm,
)
else:
assert args.unit_lm, "lexicon free decoding can only be done with a unit language model"
from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions
d = {w: [[w]] for w in tgt_dict.symbols}
self.word_dict = create_word_dict(d)
self.lm = KenLM(args.kenlm_model, self.word_dict)
self.decoder_opts = LexiconFreeDecoderOptions(
beam_size=args.beam,
beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
beam_threshold=args.beam_threshold,
lm_weight=args.lm_weight,
sil_score=args.sil_weight,
log_add=False,
criterion_type=self.criterion_type,
)
self.decoder = LexiconFreeDecoder(
self.decoder_opts, self.lm, self.silence, self.blank, []
)
def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
def idx_to_word(idx):
if self.unit_lm:
return self.idx_to_wrd[idx]
else:
return self.word_dict[idx]
def make_hypo(result):
hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
if self.lexicon:
hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
return hypo
for b in range(B):
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
results = self.decoder.decode(emissions_ptr, T, N)
nbest_results = results[: self.nbest]
hypos.append([make_hypo(result) for result in nbest_results])
self.lm.empty_cache()
return hypos
# Speech-to-Text (S2T) Modeling
[https://www.aclweb.org/anthology/2020.aacl-demo.6](https://www.aclweb.org/anthology/2020.aacl-demo.6.pdf)
Speech recognition (ASR) and speech-to-text translation (ST) with fairseq.
## Data Preparation
S2T modeling data consists of source speech features, target text and other optional information
(source text, speaker id, etc.). Fairseq S2T uses per-dataset-split TSV manifest files
to store this information. Each data field is represented by a column in the TSV file.
Unlike text token embeddings, speech features (e.g. log mel-scale filter banks) are usually fixed
during model training and can be pre-computed. The manifest file contains the path to
either the feature file in NumPy format or the WAV/FLAC audio file. For the latter,
features will be extracted on-the-fly by fairseq S2T. Optionally, feature/audio files can be packed
into uncompressed ZIP files (then accessed via byte offset and length) to improve I/O performance.
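For illustration only, here is a minimal sketch of writing such a manifest with pandas. The column
names (`id`, `audio`, `n_frames`, `tgt_text`, `speaker`) and the import path of `save_df_to_tsv`
are assumptions; follow the dataset-specific preparation scripts for the actual layout.

```python
import pandas as pd
from examples.speech_to_text.data_utils import save_df_to_tsv  # hypothetical import path

manifest = pd.DataFrame({
    "id": ["utt_0001"],
    "audio": ["fbank80.zip:1024:40000"],  # ZIP entry referenced by byte offset and length
    "n_frames": [500],                    # number of feature frames
    "tgt_text": ["hello world"],
    "speaker": ["spk_01"],
})
save_df_to_tsv(manifest, "train_asr.tsv")
```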
Fairseq S2T also employs a YAML file for data-related configuration: tokenizer type and dictionary path
for the target text, feature transforms such as CMVN (cepstral mean and variance normalization) and SpecAugment,
temperature-based resampling, etc.
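The config YAML can be written by hand or generated. Below is a hedged sketch using the `gen_vocab` and
`gen_config_yaml` helpers from the data preparation utilities (the import path and file names here are
assumptions for illustration):

```python
from pathlib import Path
from examples.speech_to_text.data_utils import gen_vocab, gen_config_yaml  # hypothetical import path

root = Path("data/en_asr")
# Train a SentencePiece model on the target text and export a fairseq dictionary.
gen_vocab(root / "train_text.txt", root / "spm_unigram1000", model_type="unigram", vocab_size=1000)
# Write config.yaml pointing at the vocab/tokenizer, with SpecAugment and utterance-level CMVN.
gen_config_yaml(root, spm_filename="spm_unigram1000.model", specaugment_policy="lb", cmvn_type="utterance")
```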
## Model Training
Fairseq S2T uses the unified `fairseq-train` interface for model training. It requires arguments `--task speech_to_text`,
`--arch <model architecture in fairseq.models.speech_to_text.*>` and `--config-yaml <config YAML filename>`.
## Inference & Evaluation
Fairseq S2T uses the unified `fairseq-generate`/`fairseq-interactive` interface for inference and evaluation. It
requires arguments `--task speech_to_text` and `--config-yaml <config YAML filename>`. The interactive console takes
audio paths (one per line) as inputs.
## Examples
- [Speech Recognition (ASR) on LibriSpeech](docs/librispeech_example.md)
- [Speech-to-Text Translation (ST) on MuST-C](docs/mustc_example.md)
- [Speech-to-Text Translation (ST) on CoVoST 2](docs/covost_example.md)
- [Speech-to-Text Translation (ST) on Multilingual TEDx](docs/mtedx_example.md)
- [Simultaneous Speech-to-Text Translation (SimulST) on MuST-C](docs/simulst_mustc_example.md)
## Updates
- 02/04/2021: Added interactive decoding (`fairseq-interactive`) support. Examples:
[ASR (LibriSpeech)](docs/librispeech_example.md#interactive-decoding)
and [ST (CoVoST 2)](docs/covost_example.md#interactive-decoding).
- 01/08/2021: Several fixes for the S2T Transformer model, inference-time de-tokenization, scorer configuration, and data
preparation scripts. We also added pre-trained models to the examples and revised the instructions.
Breaking changes: the data preparation scripts now extract filterbank features without CMVN. CMVN is instead applied
on-the-fly (defined in the config YAML).
## What's Next
- We are migrating the old fairseq [ASR example](../speech_recognition) into this S2T framework and
merging the features from both sides.
- The following papers also base their experiments on fairseq S2T. We are adding more examples for replication.
- [Improving Cross-Lingual Transfer Learning for End-to-End Speech Recognition with Speech Translation (Wang et al., 2020)](https://arxiv.org/abs/2006.05474)
- [Self-Supervised Representations Improve End-to-End Speech Translation (Wu et al., 2020)](https://arxiv.org/abs/2006.12124)
- [Self-Training for End-to-End Speech Translation (Pino et al., 2020)](https://arxiv.org/abs/2006.02490)
- [CoVoST: A Diverse Multilingual Speech-To-Text Translation Corpus (Wang et al., 2020)](https://arxiv.org/abs/2002.01320)
- [Harnessing Indirect Training Data for End-to-End Automatic Speech Translation: Tricks of the Trade (Pino et al., 2019)](https://arxiv.org/abs/1909.06515)
## Citation
Please cite as:
```
@inproceedings{wang2020fairseqs2t,
title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq},
author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino},
booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations},
year = {2020},
}
@inproceedings{ott2019fairseq,
title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
year = {2019},
}
```
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import csv
from pathlib import Path
import zipfile
from functools import reduce
from multiprocessing import cpu_count
from typing import Any, Dict, List, Optional, Union
import numpy as np
import pandas as pd
import sentencepiece as sp
from fairseq.data.audio.audio_utils import (
_convert_to_mono, _get_kaldi_fbank, _get_torchaudio_fbank
)
import torch
from tqdm import tqdm
UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3
BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0
EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2
PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1
def gen_vocab(
input_path: Path, output_path_prefix: Path, model_type="bpe",
vocab_size=1000, special_symbols: Optional[List[str]] = None
):
# Train SentencePiece Model
arguments = [
f"--input={input_path.as_posix()}",
f"--model_prefix={output_path_prefix.as_posix()}",
f"--model_type={model_type}",
f"--vocab_size={vocab_size}",
"--character_coverage=1.0",
f"--num_threads={cpu_count()}",
f"--unk_id={UNK_TOKEN_ID}",
f"--bos_id={BOS_TOKEN_ID}",
f"--eos_id={EOS_TOKEN_ID}",
f"--pad_id={PAD_TOKEN_ID}",
]
if special_symbols is not None:
_special_symbols = ",".join(special_symbols)
arguments.append(f"--user_defined_symbols={_special_symbols}")
sp.SentencePieceTrainer.Train(" ".join(arguments))
# Export fairseq dictionary
spm = sp.SentencePieceProcessor()
spm.Load(output_path_prefix.as_posix() + ".model")
vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
assert (
vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
)
vocab = {
i: s
for i, s in vocab.items()
if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
}
with open(output_path_prefix.as_posix() + ".txt", "w") as f_out:
for _, s in sorted(vocab.items(), key=lambda x: x[0]):
f_out.write(f"{s} 1\n")
def extract_fbank_features(
waveform: torch.FloatTensor,
sample_rate: int,
output_path: Optional[Path] = None,
n_mel_bins: int = 80,
overwrite: bool = False,
):
if output_path is not None and output_path.is_file() and not overwrite:
return
_waveform = _convert_to_mono(waveform, sample_rate)
_waveform = _waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers
_waveform = _waveform.numpy()
features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins)
if features is None:
features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
if features is None:
raise ImportError(
"Please install pyKaldi or torchaudio to enable fbank feature extraction"
)
if output_path is not None:
np.save(output_path.as_posix(), features)
else:
return features
def create_zip(data_root: Path, zip_path: Path):
paths = list(data_root.glob("*.npy"))
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
for path in tqdm(paths):
f.write(path, arcname=path.name)
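# NumPy ".npy" files begin with the magic bytes b"\x93NUMPY"; is_npy_data() below checks
# the first two bytes (0x93 == 147 and ord("N") == 78) as a cheap sanity check that the
# bytes sliced out of the ZIP archive are a valid .npy payload.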
def is_npy_data(data: bytes) -> bool:
return data[0] == 147 and data[1] == 78
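# get_zip_manifest() below records each utterance as "<zip path>:<byte offset>:<size>".
# The offset skips the 30-byte local file header plus the filename, which is where the raw
# .npy bytes start for the ZIP_STORED entries written by create_zip() (assuming no extra
# field is present in the local header).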
def get_zip_manifest(zip_path: Path, zip_root: Optional[Path] = None):
_zip_path = zip_path if zip_root is None else Path.joinpath(zip_root, zip_path)
with zipfile.ZipFile(_zip_path, mode="r") as f:
info = f.infolist()
manifest = {}
for i in tqdm(info):
utt_id = Path(i.filename).stem
offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
manifest[utt_id] = f"{zip_path.as_posix()}:{offset}:{file_size}"
with open(_zip_path, "rb") as f:
f.seek(offset)
data = f.read(file_size)
assert len(data) > 1 and is_npy_data(data)
return manifest
def gen_config_yaml(
manifest_root: Path,
spm_filename: str,
yaml_filename: str = "config.yaml",
specaugment_policy: str = "lb",
prepend_tgt_lang_tag: bool = False,
sampling_alpha: float = 1.0,
audio_root: str = "",
cmvn_type: str = "utterance",
gcmvn_path: Optional[Path] = None,
):
manifest_root = manifest_root.absolute()
writer = S2TDataConfigWriter(manifest_root / yaml_filename)
writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
writer.set_input_channels(1)
writer.set_input_feat_per_channel(80)
specaugment_setters = {
"lb": writer.set_specaugment_lb_policy,
"ld": writer.set_specaugment_ld_policy,
"sm": writer.set_specaugment_sm_policy,
"ss": writer.set_specaugment_ss_policy,
}
specaugment_setter = specaugment_setters.get(specaugment_policy, None)
if specaugment_setter is not None:
specaugment_setter()
writer.set_bpe_tokenizer(
{
"bpe": "sentencepiece",
"sentencepiece_model": (manifest_root / spm_filename).as_posix(),
}
)
if prepend_tgt_lang_tag:
writer.set_prepend_tgt_lang_tag(True)
writer.set_sampling_alpha(sampling_alpha)
if cmvn_type not in ["global", "utterance"]:
raise NotImplementedError
writer.set_feature_transforms("_train", [f"{cmvn_type}_cmvn", "specaugment"])
writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"])
if cmvn_type == "global":
assert gcmvn_path is not None, (
'Please provide path of global cmvn file.'
)
writer.set_global_cmvn(str(gcmvn_path))
if len(audio_root) > 0:
writer.set_audio_root(audio_root)
writer.flush()
def load_df_from_tsv(path: Union[str, Path]):
_path = path if isinstance(path, str) else path.as_posix()
return pd.read_csv(
_path,
sep="\t",
header=0,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_NONE,
na_filter=False,
)
def save_df_to_tsv(dataframe, path: Union[str, Path]):
_path = path if isinstance(path, str) else path.as_posix()
dataframe.to_csv(
_path,
sep="\t",
header=True,
index=False,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_NONE,
)
def filter_manifest_df(
df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
):
filters = {
"no speech": df["audio"] == "",
f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
"empty sentence": df["tgt_text"] == "",
}
if is_train_split:
filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
if extra_filters is not None:
filters.update(extra_filters)
invalid = reduce(lambda x, y: x | y, filters.values())
valid = ~invalid
print(
"| "
+ ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
+ f", total {invalid.sum()} filtered, {valid.sum()} remained."
)
return df[valid]
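# cal_gcmvn_stats() below computes global CMVN statistics over all frames of the given
# utterances: the per-dimension mean, and the standard deviation via E[x^2] - mean^2
# (floored at 1e-8 for numerical stability).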
def cal_gcmvn_stats(features_list):
features = np.concatenate(features_list)
square_sums = (features ** 2).sum(axis=0)
mean = features.mean(axis=0)
features = np.subtract(features, mean)
var = square_sums / features.shape[0] - mean ** 2
std = np.sqrt(np.maximum(var, 1e-8))
return {"mean": mean.astype("float32"), "std": std.astype("float32")}
class S2TDataConfigWriter(object):
DEFAULT_VOCAB_FILENAME = "dict.txt"
DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
DEFAULT_INPUT_CHANNELS = 1
def __init__(self, yaml_path: Path):
try:
import yaml
except ImportError:
print("Please install PyYAML for S2T data config YAML files")
self.yaml = yaml
self.yaml_path = yaml_path
self.config = {}
def flush(self):
with open(self.yaml_path, "w") as f:
self.yaml.dump(self.config, f)
def set_audio_root(self, audio_root=""):
self.config["audio_root"] = audio_root
def set_vocab_filename(self, vocab_filename: str = "dict.txt"):
self.config["vocab_filename"] = vocab_filename
def set_specaugment(
self,
time_wrap_w: int,
freq_mask_n: int,
freq_mask_f: int,
time_mask_n: int,
time_mask_t: int,
time_mask_p: float,
):
self.config["specaugment"] = {
"time_wrap_W": time_wrap_w,
"freq_mask_N": freq_mask_n,
"freq_mask_F": freq_mask_f,
"time_mask_N": time_mask_n,
"time_mask_T": time_mask_t,
"time_mask_p": time_mask_p,
}
def set_specaugment_lb_policy(self):
self.set_specaugment(
time_wrap_w=0,
freq_mask_n=1,
freq_mask_f=27,
time_mask_n=1,
time_mask_t=100,
time_mask_p=1.0,
)
def set_specaugment_ld_policy(self):
self.set_specaugment(
time_wrap_w=0,
freq_mask_n=2,
freq_mask_f=27,
time_mask_n=2,
time_mask_t=100,
time_mask_p=1.0,
)
def set_specaugment_sm_policy(self):
self.set_specaugment(
time_wrap_w=0,
freq_mask_n=2,
freq_mask_f=15,
time_mask_n=2,
time_mask_t=70,
time_mask_p=0.2,
)
def set_specaugment_ss_policy(self):
self.set_specaugment(
time_wrap_w=0,
freq_mask_n=2,
freq_mask_f=27,
time_mask_n=2,
time_mask_t=70,
time_mask_p=0.2,
)
def set_input_channels(self, input_channels: int = 1):
self.config["input_channels"] = input_channels
def set_input_feat_per_channel(self, input_feat_per_channel: int = 80):
self.config["input_feat_per_channel"] = input_feat_per_channel
def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
self.config["bpe_tokenizer"] = bpe_tokenizer
def set_global_cmvn(self, stats_npz_path: str):
self.config["global_cmvn"] = {"stats_npz_path": stats_npz_path}
def set_feature_transforms(self, split: str, transforms: List[str]):
if "transforms" not in self.config:
self.config["transforms"] = {}
self.config["transforms"][split] = transforms
def set_prepend_tgt_lang_tag(self, flag: bool = True):
self.config["prepend_tgt_lang_tag"] = flag
def set_sampling_alpha(self, sampling_alpha: float = 1.0):
self.config["sampling_alpha"] = sampling_alpha
[[Back]](..)
# S2T Example: ST on CoVoST
We replicate the experiments in
[CoVoST 2 and Massively Multilingual Speech-to-Text Translation (Wang et al., 2020)](https://arxiv.org/abs/2007.10310).
## Data Preparation
[Download](https://commonvoice.mozilla.org/en/datasets) and unpack Common Voice v4 to a path
`${COVOST_ROOT}/${SOURCE_LANG_ID}`, then preprocess it with
```bash
# additional Python packages for S2T data processing/model training
pip install pandas torchaudio sentencepiece
# En ASR
python examples/speech_to_text/prep_covost_data.py \
--data-root ${COVOST_ROOT} --vocab-type char --src-lang en
# ST
python examples/speech_to_text/prep_covost_data.py \
--data-root ${COVOST_ROOT} --vocab-type char \
--src-lang fr --tgt-lang en
```
The generated files (manifest, features, vocabulary and data configuration) will be added to
`${COVOST_ROOT}/${SOURCE_LANG_ID}`.
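To sanity-check the generated manifests, you can load one with pandas. This is just a sketch (the helper `load_df_from_tsv` in `examples/speech_to_text/data_utils.py` uses the same settings), and the file name `train_st_fr_en.tsv` assumes the Fr-En ST preprocessing above:
```python
import csv
import pandas as pd

# Load a generated manifest (tab-separated, no quoting, as written by save_df_to_tsv).
df = pd.read_csv("train_st_fr_en.tsv", sep="\t", quoting=csv.QUOTE_NONE, na_filter=False)
print(df.columns.tolist())   # ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
print(len(df), "utterances;", int(df["n_frames"].sum()), "feature frames in total")
```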
Download our vocabulary files if you want to use our pre-trained models:
- ASR: [En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_asr_vocab_char.zip)
- ST: [Fr-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_fr_en_st_vocab_char.zip), [De-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_de_en_st_vocab_char.zip), [Es-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_es_en_st_vocab_char.zip), [Ca-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_ca_en_st_vocab_char.zip), [En-De](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_vocab_char.zip), [En-Ca](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_ca_st_vocab_char.zip), [En-Fa](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_fa_st_vocab_char.zip), [En-Et](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_et_st_vocab_char.zip)
## ASR
#### Training
We train an En ASR model for encoder pre-training of all ST models:
```bash
fairseq-train ${COVOST_ROOT}/en \
--config-yaml config_asr_en.yaml --train-subset train_asr_en --valid-subset dev_asr_en \
--save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 50000 --max-update 60000 \
--task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--report-accuracy --arch s2t_transformer_s --dropout 0.15 --optimizer adam --lr 2e-3 \
--lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
```
where `ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
You may want to update it accordingly when using more than 1 GPU.
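As a rough rule of thumb (an approximation, not an exact fairseq formula), the number of target tokens consumed per optimizer step scales with `--max-tokens`, the number of GPUs and `--update-freq`:
```python
# Approximate tokens consumed per optimizer step under gradient accumulation.
def tokens_per_update(max_tokens: int, n_gpus: int, update_freq: int) -> int:
    return max_tokens * n_gpus * update_freq

print(tokens_per_update(50000, 1, 8))  # 400000, same as 8 GPUs with --update-freq 1
```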
#### Inference & Evaluation
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \
--output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${COVOST_ROOT}/en \
--config-yaml config_asr_en.yaml --gen-subset test_asr_en --task speech_to_text \
--path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
--scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
```
#### Results (WER)
| --arch | Params | En | Model |
|---|---|---|---|
| s2t_transformer_s | 31M | 25.6 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_asr_transformer_s.pt) |
## ST
#### Training
Fr-En as an example (use `--max-tokens 50000` for En-* directions):
```bash
fairseq-train ${COVOST_ROOT}/fr \
--config-yaml config_st_fr_en.yaml --train-subset train_st_fr_en --valid-subset dev_st_fr_en \
  --save-dir ${ST_SAVE_DIR} --num-workers 4 --max-update 30000 --max-tokens 40000 \
--task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
--arch s2t_transformer_s --encoder-freezing-updates 1000 --optimizer adam --lr 2e-3 \
--lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
--load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
```
where `ST_SAVE_DIR` is the checkpoint root path. The ST encoder is pre-trained by En ASR for faster training and better
performance: `--load-pretrained-encoder-from <ASR checkpoint path>`. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
You may want to update it accordingly when using more than 1 GPU.
#### Inference & Evaluation
Average the last 10 checkpoints and evaluate on the test split:
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \
--output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${COVOST_ROOT}/fr \
--config-yaml config_st_fr_en.yaml --gen-subset test_st_fr_en --task speech_to_text \
--path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
--max-tokens 50000 --beam 5 --scoring sacrebleu
```
## Interactive Decoding
Launch the interactive console via
```bash
fairseq-interactive ${COVOST_ROOT}/fr --config-yaml config_st_fr_en.yaml \
--task speech_to_text --path ${SAVE_DIR}/${CHECKPOINT_FILENAME} \
--max-tokens 50000 --beam 5
```
Type in WAV/FLAC/OGG audio paths (one per line) after the prompt.
#### Results (BLEU)
| --arch | Params | Fr-En | De-En | Es-En | Ca-En | En-De | En-Ca | En-Fa | En-Et | Model |
|---|---|---|---|---|---|---|---|---|---|---|
| s2t_transformer_s | 31M | [27.2](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_fr_en_st_transformer_s.pt) | [17.7](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_de_en_st_transformer_s.pt) | [23.1](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_es_en_st_transformer_s.pt) | [19.3](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_ca_en_st_transformer_s.pt) | [16.1](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_transformer_s.pt) | [21.6](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_ca_st_transformer_s.pt) | [12.9](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_fa_st_transformer_s.pt) | [12.8](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_et_st_transformer_s.pt) | (<-Download) |
[[Back]](..)
# S2T Example: Speech Recognition (ASR) on LibriSpeech
[LibriSpeech](https://www.danielpovey.com/files/2015_icassp_librispeech.pdf) is a de-facto standard English ASR
benchmark. We provide competitive
vanilla [Transformer](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf) baselines.
## Data preparation
Download and preprocess LibriSpeech data with
```bash
# additional Python packages for S2T data processing/model training
pip install pandas torchaudio sentencepiece
python examples/speech_to_text/prep_librispeech_data.py \
--output-root ${LS_ROOT} --vocab-type unigram --vocab-size 10000
```
where `LS_ROOT` is the root path for downloaded data as well as generated files (manifest, features, vocabulary and
data configuration).
[Download](https://dl.fbaipublicfiles.com/fairseq/s2t/librispeech_vocab_unigram10000.zip) our vocabulary files
if you want to use our pre-trained models.
## Training
```bash
fairseq-train ${LS_ROOT} --save-dir ${SAVE_DIR} \
--config-yaml config.yaml --train-subset train-clean-100,train-clean-360,train-other-500 --valid-subset dev-clean,dev-other \
--num-workers 4 --max-tokens 40000 --max-update 300000 \
--task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
--arch s2t_transformer_s --share-decoder-input-output-embed \
--optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 10000 \
--clip-norm 10.0 --seed 1 --update-freq 8
```
where `SAVE_DIR` is the checkpoint root path. Here we use `--arch s2t_transformer_s` (31M parameters) as an example.
For better performance, you may switch to `s2t_transformer_m` (71M, with `--lr 1e-3`) or `s2t_transformer_l`
(268M, with `--lr 5e-4`). We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly
when using more than 1 GPU.
## Inference & Evaluation
Average the last 10 checkpoints and evaluate on the 4 splits
(`dev-clean`, `dev-other`, `test-clean` and `test-other`):
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py --inputs ${SAVE_DIR} \
--num-epoch-checkpoints 10 \
--output "${SAVE_DIR}/${CHECKPOINT_FILENAME}"
for SUBSET in dev-clean dev-other test-clean test-other; do
fairseq-generate ${LS_ROOT} --config-yaml config.yaml --gen-subset ${SUBSET} \
--task speech_to_text --path ${SAVE_DIR}/${CHECKPOINT_FILENAME} \
--max-tokens 50000 --beam 5 --scoring wer
done
```
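For reference, checkpoint averaging simply averages the model parameters of the selected checkpoints. The sketch below shows the idea (it assumes the fairseq checkpoint layout with a `model` state dict; `scripts/average_checkpoints.py` remains the tool to use in practice):
```python
import torch

def average_model_states(paths):
    """Average the "model" state dicts of several fairseq checkpoints (sketch only)."""
    avg = None
    for p in paths:
        state = torch.load(p, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float()
    return {k: v / len(paths) for k, v in avg.items()}
```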
## Interactive Decoding
Launch the interactive console via
```bash
fairseq-interactive ${LS_ROOT} --config-yaml config.yaml --task speech_to_text \
--path ${SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5
```
Type in WAV/FLAC/OGG audio paths (one per line) after the prompt.
## Results (WER)
| --arch | Params | dev-clean | dev-other | test-clean | test-other | Model |
|---|---|---|---|---|---|---|
| s2t_transformer_s | 30M | 3.8 | 8.9 | 4.4 | 9.0 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/librispeech_transformer_s.pt) |
| s2t_transformer_m | 71M | 3.2 | 8.0 | 3.4 | 7.9 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/librispeech_transformer_m.pt) |
| s2t_transformer_l | 268M | 3.0 | 7.5 | 3.2 | 7.5 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/librispeech_transformer_l.pt) |
[[Back]](..)
# S2T Example: Speech Translation (ST) on Multilingual TEDx
[Multilingual TEDx](https://arxiv.org/abs/2102.01757) is a multilingual corpus for speech recognition and
speech translation. The data is derived from TEDx talks in 8 source languages
with translations to a subset of 5 target languages.
## Data Preparation
[Download](http://openslr.org/100/) and unpack Multilingual TEDx data to a path
`${MTEDX_ROOT}/${LANG_PAIR}`, then preprocess it with
```bash
# additional Python packages for S2T data processing/model training
pip install pandas torchaudio soundfile sentencepiece
# Generate TSV manifests, features, vocabulary
# and configuration for each language
python examples/speech_to_text/prep_mtedx_data.py \
--data-root ${MTEDX_ROOT} --task asr \
--vocab-type unigram --vocab-size 1000
python examples/speech_to_text/prep_mtedx_data.py \
--data-root ${MTEDX_ROOT} --task st \
--vocab-type unigram --vocab-size 1000
# Add vocabulary and configuration for joint data
# (based on the manifests and features generated above)
python examples/speech_to_text/prep_mtedx_data.py \
--data-root ${MTEDX_ROOT} --task asr --joint \
--vocab-type unigram --vocab-size 8000
python examples/speech_to_text/prep_mtedx_data.py \
--data-root ${MTEDX_ROOT} --task st --joint \
--vocab-type unigram --vocab-size 8000
```
The generated files (manifest, features, vocabulary and data configuration) will be added to
`${MTEDX_ROOT}/${LANG_PAIR}` (per-language data) and `MTEDX_ROOT` (joint data).
## ASR
#### Training
Spanish as an example:
```bash
fairseq-train ${MTEDX_ROOT}/es-es \
--config-yaml config_asr.yaml --train-subset train_asr --valid-subset valid_asr \
--save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \
--task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \
--arch s2t_transformer_xs --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
--warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \
--load-pretrained-encoder-from ${PRETRAINED_ENCODER} \
--skip-invalid-size-inputs-valid-test \
--keep-last-epochs 10 --update-freq 8 --patience 10
```
For a joint model (using ASR data from all 8 languages):
```bash
fairseq-train ${MTEDX_ROOT} \
--config-yaml config_asr.yaml \
--train-subset train_es-es_asr,train_fr-fr_asr,train_pt-pt_asr,train_it-it_asr,train_ru-ru_asr,train_el-el_asr,train_ar-ar_asr,train_de-de_asr \
--valid-subset valid_es-es_asr,valid_fr-fr_asr,valid_pt-pt_asr,valid_it-it_asr,valid_ru-ru_asr,valid_el-el_asr,valid_ar-ar_asr,valid_de-de_asr \
--save-dir ${MULTILINGUAL_ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \
--task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \
--arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
--warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \
--skip-invalid-size-inputs-valid-test \
--keep-last-epochs 10 --update-freq 8 --patience 10 \
--ignore-prefix-size 1
```
where `MULTILINGUAL_ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs
with 1 GPU. You may want to update it accordingly when using more than 1 GPU.
For multilingual models, we prepend the target language ID token as the target BOS, which should be excluded from
the training loss via `--ignore-prefix-size 1` (see the illustration below).
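To make this concrete, here is an illustration (not fairseq internals) of what a multilingual target sample looks like: the language tag token (e.g. `<lang:en>`, the format used by `prep_mtedx_data.py` when building the joint vocabulary) occupies the first target position, which `--ignore-prefix-size 1` excludes from the loss and `--prefix-size 1` later forces as the decoding prefix:
```python
# Illustration only: a multilingual es-en target with the language ID tag prepended.
tgt_tokens = ["<lang:en>", "▁this", "▁is", "▁an", "▁example"]

loss_positions = tgt_tokens[1:]   # --ignore-prefix-size 1: the tag does not contribute to the loss
decode_prefix = tgt_tokens[:1]    # --prefix-size 1: decoding is forced to start from the tag
print(loss_positions, decode_prefix)
```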
#### Inference & Evaluation
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \
--output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${MTEDX_ROOT}/es-es \
--config-yaml config_asr.yaml --gen-subset test --task speech_to_text \
--path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
--skip-invalid-size-inputs-valid-test \
--scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct --remove-bpe
# For models trained on joint data
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${MULTILINGUAL_ASR_SAVE_DIR} --num-epoch-checkpoints 10 \
--output "${MULTILINGUAL_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
for LANG in es fr pt it ru el ar de; do
fairseq-generate ${MTEDX_ROOT} \
--config-yaml config_asr.yaml --gen-subset test_${LANG}-${LANG}_asr --task speech_to_text \
--prefix-size 1 --path ${MULTILINGUAL_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \
--max-tokens 40000 --beam 5 \
--skip-invalid-size-inputs-valid-test \
--scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct --remove-bpe
done
```
#### Results (WER)
| Data | --arch | Params | Es | Fr | Pt | It | Ru | El | Ar | De |
|--------------|--------------------|--------|------|------|------|------|------|-------|-------|-------|
| Monolingual | s2t_transformer_xs | 10M | 46.4 | 45.6 | 54.8 | 48.0 | 74.7 | 109.5 | 104.4 | 111.1 |
## ST
#### Training
Es-En as an example:
```bash
fairseq-train ${MTEDX_ROOT}/es-en \
--config-yaml config_st.yaml --train-subset train_st --valid-subset valid_st \
--save-dir ${ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \
--task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \
--arch s2t_transformer_xs --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
--warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \
--load-pretrained-encoder-from ${PRETRAINED_ENCODER} \
--skip-invalid-size-inputs-valid-test \
--keep-last-epochs 10 --update-freq 8 --patience 10
```
For a multilingual model (all 12 directions):
```bash
fairseq-train ${MTEDX_ROOT} \
--config-yaml config_st.yaml \
--train-subset train_el-en_st,train_es-en_st,train_es-fr_st,train_es-it_st,train_es-pt_st,train_fr-en_st,train_fr-es_st,train_fr-pt_st,train_it-en_st,train_it-es_st,train_pt-en_st,train_pt-es_st,train_ru-en_st \
--valid-subset valid_el-en_st,valid_es-en_st,valid_es-fr_st,valid_es-it_st,valid_es-pt_st,valid_fr-en_st,valid_fr-es_st,valid_fr-pt_st,valid_it-en_st,valid_it-es_st,valid_pt-en_st,valid_pt-es_st,valid_ru-en_st \
--save-dir ${MULTILINGUAL_ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \
--task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \
--arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
--warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \
--skip-invalid-size-inputs-valid-test \
--keep-last-epochs 10 --update-freq 8 --patience 10 \
--ignore-prefix-size 1 \
--load-pretrained-encoder-from ${PRETRAINED_ENCODER}
```
where `ST_SAVE_DIR` (`MULTILINGUAL_ST_SAVE_DIR`) is the checkpoint root path. The ST encoder is pre-trained by ASR
for faster training and better performance: `--load-pretrained-encoder-from <(JOINT_)ASR checkpoint path>`. We set
`--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU.
For multilingual models, we prepend the target language ID token as the target BOS, which should be excluded from
the training loss via `--ignore-prefix-size 1`.
#### Inference & Evaluation
Average the last 10 checkpoints and evaluate on the `test` split:
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \
--output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${MTEDX_ROOT}/es-en \
--config-yaml config_st.yaml --gen-subset test --task speech_to_text \
--path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
--max-tokens 50000 --beam 5 --scoring sacrebleu --remove-bpe
# For multilingual models
python scripts/average_checkpoints.py \
--inputs ${MULTILINGUAL_ST_SAVE_DIR} --num-epoch-checkpoints 10 \
--output "${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
for LANGPAIR in es-en es-fr es-pt fr-en fr-es fr-pt pt-en pt-es it-en it-es ru-en el-en; do
fairseq-generate ${MTEDX_ROOT} \
--config-yaml config_st.yaml --gen-subset test_${LANGPAIR}_st --task speech_to_text \
--prefix-size 1 --path ${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
--max-tokens 40000 --beam 5 \
--skip-invalid-size-inputs-valid-test \
--scoring sacrebleu --remove-bpe
done
```
For multilingual models, we force decoding from the target language ID token (as BOS) via `--prefix-size 1`.
#### Results (BLEU)
| Data | --arch | Params | Es-En | Es-Pt | Es-Fr | Fr-En | Fr-Es | Fr-Pt | Pt-En | Pt-Es | It-En | It-Es | Ru-En | El-En |
|--------------|--------------------|-----|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|
| Bilingual | s2t_transformer_xs | 10M | 7.0 | 12.2 | 1.7 | 8.9 | 10.6 | 7.9 | 8.1 | 8.7 | 6.4 | 1.0 | 0.7 | 0.6 |
| Multilingual | s2t_transformer_s | 31M | 12.3 | 17.4 | 6.1 | 12.0 | 13.6 | 13.2 | 12.0 | 13.7 | 10.7 | 13.1 | 0.6 | 0.8 |
## Citation
Please cite as:
```
@misc{salesky2021mtedx,
title={Multilingual TEDx Corpus for Speech Recognition and Translation},
author={Elizabeth Salesky and Matthew Wiesner and Jacob Bremerman and Roldano Cattoni and Matteo Negri and Marco Turchi and Douglas W. Oard and Matt Post},
year={2021},
}
@inproceedings{wang2020fairseqs2t,
title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq},
author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino},
booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations},
year = {2020},
}
@inproceedings{ott2019fairseq,
title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
year = {2019},
}
```
[[Back]](..)
# S2T Example: Speech Translation (ST) on MuST-C
[MuST-C](https://www.aclweb.org/anthology/N19-1202) is a multilingual speech-to-text translation corpus with
translations of English TED talks into 8 languages. We match the state-of-the-art performance in
[ESPNet-ST](https://arxiv.org/pdf/2004.10234.pdf) with a simpler model training pipeline.
## Data Preparation
[Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path
`${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with
```bash
# additional Python packages for S2T data processing/model training
pip install pandas torchaudio soundfile sentencepiece
# Generate TSV manifests, features, vocabulary
# and configuration for each language
python examples/speech_to_text/prep_mustc_data.py \
--data-root ${MUSTC_ROOT} --task asr \
--vocab-type unigram --vocab-size 5000
python examples/speech_to_text/prep_mustc_data.py \
--data-root ${MUSTC_ROOT} --task st \
--vocab-type unigram --vocab-size 8000
# Add vocabulary and configuration for joint data
# (based on the manifests and features generated above)
python examples/speech_to_text/prep_mustc_data.py \
--data-root ${MUSTC_ROOT} --task asr --joint \
--vocab-type unigram --vocab-size 10000
python examples/speech_to_text/prep_mustc_data.py \
--data-root ${MUSTC_ROOT} --task st --joint \
--vocab-type unigram --vocab-size 10000
```
The generated files (manifest, features, vocabulary and data configuration) will be added to
`${MUSTC_ROOT}/en-${TARGET_LANG_ID}` (per-language data) and `MUSTC_ROOT` (joint data).
Download our vocabulary files if you want to use our pre-trained models:
- ASR: [En-De](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_asr_vocab_unigram5000.zip), [En-Nl](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_nl_asr_vocab_unigram5000.zip), [En-Es](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_asr_vocab_unigram5000.zip), [En-Fr](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_fr_asr_vocab_unigram5000.zip), [En-It](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_it_asr_vocab_unigram5000.zip), [En-Pt](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_pt_asr_vocab_unigram5000.zip), [En-Ro](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ro_asr_vocab_unigram5000.zip), [En-Ru](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ru_asr_vocab_unigram5000.zip), [Joint](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_joint_asr_vocab_unigram10000.zip)
- ST: [En-De](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_st_vocab_unigram8000.zip), [En-Nl](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_nl_st_vocab_unigram8000.zip), [En-Es](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_st_vocab_unigram8000.zip), [En-Fr](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_fr_st_vocab_unigram8000.zip), [En-It](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_it_st_vocab_unigram8000.zip), [En-Pt](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_pt_st_vocab_unigram8000.zip), [En-Ro](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ro_st_vocab_unigram8000.zip), [En-Ru](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ru_st_vocab_unigram8000.zip), [Multilingual](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_multilingual_st_vocab_unigram10000.zip)
## ASR
#### Training
En-De as an example:
```bash
fairseq-train ${MUSTC_ROOT}/en-de \
--config-yaml config_asr.yaml --train-subset train_asr --valid-subset dev_asr \
--save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \
--task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
--arch s2t_transformer_s --optimizer adam --lr 1e-3 --lr-scheduler inverse_sqrt \
--warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
```
For a joint model (using ASR data from all 8 directions):
```bash
fairseq-train ${MUSTC_ROOT} \
--config-yaml config_asr.yaml \
--train-subset train_de_asr,train_nl_asr,train_es_asr,train_fr_asr,train_it_asr,train_pt_asr,train_ro_asr,train_ru_asr \
--valid-subset dev_de_asr,dev_nl_asr,dev_es_asr,dev_fr_asr,dev_it_asr,dev_pt_asr,dev_ro_asr,dev_ru_asr \
--save-dir ${JOINT_ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \
--task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
--arch s2t_transformer_s --optimizer adam --lr 1e-3 --lr-scheduler inverse_sqrt \
--warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
```
where `ASR_SAVE_DIR` (`JOINT_ASR_SAVE_DIR`) is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs
with 1 GPU. You may want to update it accordingly when using more than 1 GPU.
#### Inference & Evaluation
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \
--output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${MUSTC_ROOT}/en-de \
--config-yaml config_asr.yaml --gen-subset tst-COMMON_asr --task speech_to_text \
--path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
--scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
# For models trained on joint data
python scripts/average_checkpoints.py \
--inputs ${JOINT_ASR_SAVE_DIR} --num-epoch-checkpoints 10 \
--output "${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
for LANG in de nl es fr it pt ro ru; do
fairseq-generate ${MUSTC_ROOT} \
--config-yaml config_asr.yaml --gen-subset tst-COMMON_${LANG}_asr --task speech_to_text \
--path ${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
--scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
done
```
#### Results (WER)
| Data | --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru | Model |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Single | s2t_transformer_s | 31M | [18.2](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_asr_transformer_s.pt) | [17.6](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_nl_asr_transformer_s.pt) | [17.7](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_asr_transformer_s.pt) | [17.2](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_fr_asr_transformer_s.pt) | [17.9](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_it_asr_transformer_s.pt) | [19.1](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_pt_asr_transformer_s.pt) | [18.1](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ro_asr_transformer_s.pt) | [17.7](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ru_asr_transformer_s.pt) | (<-Download) |
| Joint | s2t_transformer_m | 76M | 16.8 | 16.7 | 16.9 | 16.9 | 17.0 | 17.4 | 17.0 | 16.9 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_joint_asr_transformer_m.pt) |
## ST
#### Training
En-De as an example:
```bash
fairseq-train ${MUSTC_ROOT}/en-de \
--config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \
--save-dir ${ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \
--task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
--arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
--warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
--load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
```
For a multilingual model (all 8 directions):
```bash
fairseq-train ${MUSTC_ROOT} \
--config-yaml config_st.yaml \
--train-subset train_de_st,train_nl_st,train_es_st,train_fr_st,train_it_st,train_pt_st,train_ro_st,train_ru_st \
--valid-subset dev_de_st,dev_nl_st,dev_es_st,dev_fr_st,dev_it_st,dev_pt_st,dev_ro_st,dev_ru_st \
--save-dir ${MULTILINGUAL_ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \
--task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
--arch s2t_transformer_s --ignore-prefix-size 1 --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
--warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
--load-pretrained-encoder-from ${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
```
where `ST_SAVE_DIR` (`MULTILINGUAL_ST_SAVE_DIR`) is the checkpoint root path. The ST encoder is pre-trained by ASR
for faster training and better performance: `--load-pretrained-encoder-from <(JOINT_)ASR checkpoint path>`. We set
`--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU.
For multilingual models, we prepend the target language ID token as the target BOS, which should be excluded from
the training loss via `--ignore-prefix-size 1`.
#### Inference & Evaluation
Average the last 10 checkpoints and evaluate on the `tst-COMMON` split:
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \
--output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${MUSTC_ROOT}/en-de \
--config-yaml config_st.yaml --gen-subset tst-COMMON_st --task speech_to_text \
--path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
--max-tokens 50000 --beam 5 --scoring sacrebleu
# For multilingual models
python scripts/average_checkpoints.py \
--inputs ${MULTILINGUAL_ST_SAVE_DIR} --num-epoch-checkpoints 10 \
--output "${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
for LANG in de nl es fr it pt ro ru; do
fairseq-generate ${MUSTC_ROOT} \
--config-yaml config_st.yaml --gen-subset tst-COMMON_${LANG}_st --task speech_to_text \
--prefix-size 1 --path ${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
--max-tokens 50000 --beam 5 --scoring sacrebleu
done
```
For multilingual models, we force decoding from the target language ID token (as BOS) via `--prefix-size 1`.
#### Results (BLEU)
| Data | --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru | Model |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Bilingual | s2t_transformer_s | 31M | [22.7](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_st_transformer_s.pt) | [27.3](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_nl_st_transformer_s.pt) | [27.2](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_st_transformer_s.pt) | [32.9](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_fr_st_transformer_s.pt) | [22.7](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_it_st_transformer_s.pt) | [28.1](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_pt_st_transformer_s.pt) | [21.9](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ro_st_transformer_s.pt) | [15.3](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ru_st_transformer_s.pt) | (<-Download) |
| Multilingual | s2t_transformer_m | 76M | 24.5 | 28.6 | 28.2 | 34.9 | 24.6 | 31.1 | 23.8 | 16.0 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_multilingual_st_transformer_m.pt) |
[[Back]](..)
# Simultaneous Speech Translation (SimulST) on MuST-C
This is a tutorial on training and evaluating a Transformer *wait-k* simultaneous model on the MuST-C English-German dataset, from [SimulMT to SimulST: Adapting Simultaneous Text Translation to End-to-End Simultaneous Speech Translation](https://www.aclweb.org/anthology/2020.aacl-main.58.pdf).
[MuST-C](https://www.aclweb.org/anthology/N19-1202) is a multilingual speech-to-text translation corpus with translations of English TED talks into 8 languages.
## Data Preparation
This section introduces the data preparation for training and evaluation.
If you only want to evaluate the model, please jump to [Inference & Evaluation](#inference--evaluation).
[Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path
`${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with
```bash
# Additional Python packages for S2T data processing/model training
pip install pandas torchaudio sentencepiece
# Generate TSV manifests, features, vocabulary,
# global cepstral mean and variance (CMVN) estimation,
# and configuration for each language
cd fairseq
python examples/speech_to_text/prep_mustc_data.py \
--data-root ${MUSTC_ROOT} --task asr \
--vocab-type unigram --vocab-size 10000 \
--cmvn-type global
python examples/speech_to_text/prep_mustc_data.py \
--data-root ${MUSTC_ROOT} --task st \
--vocab-type unigram --vocab-size 10000 \
--cmvn-type global
```
## ASR Pretraining
We need a pretrained offline ASR model. Assume its save directory is `${ASR_SAVE_DIR}`.
The following command (like the subsequent training commands in this tutorial) assumes training on 1 GPU; if you train on 8 GPUs instead, remove the `--update-freq 8` option.
```bash
fairseq-train ${MUSTC_ROOT}/en-de \
--config-yaml config_asr.yaml --train-subset train_asr --valid-subset dev_asr \
--save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \
--task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \
--arch convtransformer_espnet --optimizer adam --lr 0.0005 --lr-scheduler inverse_sqrt \
--warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
```
A pretrained ASR checkpoint can be downloaded [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1_en_de_pretrained_asr).
## Simultaneous Speech Translation Training
### Wait-K with fixed pre-decision module
Fixed pre-decision means that the model applies its simultaneous policy only at the boundaries of fixed-size chunks.
Below is an example with a fixed pre-decision ratio of 7 (a simultaneous decision is made every 7 encoder states) and
a wait-3 policy. Assume the save directory is `${ST_SAVE_DIR}`:
```bash
fairseq-train ${MUSTC_ROOT}/en-de \
--config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \
--save-dir ${ST_SAVE_DIR} --num-workers 8 \
--optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt --clip-norm 10.0 \
--criterion label_smoothed_cross_entropy \
--warmup-updates 4000 --max-update 100000 --max-tokens 40000 --seed 2 \
--load-pretrained-encoder-from ${ASR_SAVE_DIR}/checkpoint_best.pt \
--task speech_to_text \
--arch convtransformer_simul_trans_espnet \
--simul-type waitk_fixed_pre_decision \
--waitk-lagging 3 \
--fixed-pre-decision-ratio 7 \
--update-freq 8
```
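To make the read/write schedule concrete, here is a toy, framework-independent sketch (not fairseq code) of a wait-k policy with a fixed pre-decision ratio: with ratio 7 a decision is taken every 7 encoder states, and with k=3 the model starts writing once 3 such chunks have been read:
```python
def waitk_fixed_predecision_schedule(n_encoder_states, n_target_tokens, k=3, ratio=7):
    """Toy wait-k schedule over fixed pre-decision chunks (illustration only)."""
    actions, chunks_read, states_read = [], 0, 0
    for t in range(n_target_tokens):
        # Read chunks of `ratio` encoder states until k + t chunks are available
        # (or the source is exhausted), then write the next target token.
        while chunks_read < k + t and states_read < n_encoder_states:
            step = min(ratio, n_encoder_states - states_read)
            states_read += step
            chunks_read += 1
            actions.append(f"READ {step} states")
        actions.append(f"WRITE token {t + 1}")
    return actions

print(waitk_fixed_predecision_schedule(35, 6))
```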
### Monotonic multihead attention with fixed pre-decision module
```bash
fairseq-train ${MUSTC_ROOT}/en-de \
--config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \
--save-dir ${ST_SAVE_DIR} --num-workers 8 \
--optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt --clip-norm 10.0 \
--warmup-updates 4000 --max-update 100000 --max-tokens 40000 --seed 2 \
--load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \
--task speech_to_text \
--criterion latency_augmented_label_smoothed_cross_entropy \
--latency-weight-avg 0.1 \
--arch convtransformer_simul_trans_espnet \
--simul-type infinite_lookback_fixed_pre_decision \
--fixed-pre-decision-ratio 7 \
--update-freq 8
```
## Inference & Evaluation
[SimulEval](https://github.com/facebookresearch/SimulEval) is used for evaluation. The following commands install it and run the evaluation:
```bash
git clone https://github.com/facebookresearch/SimulEval.git
cd SimulEval
pip install -e .
simuleval \
    --agent ${FAIRSEQ}/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py \
    --source ${SRC_LIST_OF_AUDIO} \
    --target ${TGT_FILE} \
--data-bin ${MUSTC_ROOT}/en-de \
--config config_st.yaml \
--model-path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
--output ${OUTPUT} \
--scores
```
The source file `${SRC_LIST_OF_AUDIO}` is a list of paths to audio files. Assuming your audio files are stored at `/home/user/data`,
it should look like this:
```bash
/home/user/data/audio-1.wav
/home/user/data/audio-2.wav
```
Each line of the target file `${TGT_FILE}` is the translation of the corresponding audio file:
```bash
Translation_1
Translation_2
```
The evaluation runs on the original MuST-C segmentation.
The following command generates the wav list and text file for an evaluation set `${SPLIT}` (chosen from `dev`, `tst-COMMON` and `tst-HE`) of MuST-C into `${EVAL_DATA}`.
```bash
python ${FAIRSEQ}/examples/speech_to_text/seg_mustc_data.py \
--data-root ${MUSTC_ROOT} --lang de \
--split ${SPLIT} --task st \
--output ${EVAL_DATA}
```
The `--data-bin` and `--config` options should be the same as in the previous section if you prepared the data from scratch.
If you only need to run evaluation, a prepared data directory can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1.0_en_de_databin.tgz). It contains:
- `spm_unigram10000_st.model`: a sentencepiece model binary.
- `spm_unigram10000_st.txt`: the dictionary file generated by the sentencepiece model.
- `gcmvn.npz`: the binary for global cepstral mean and variance.
- `config_st.yaml`: the config YAML file, shown below.
You will need to set absolute paths for `sentencepiece_model` and `stats_npz_path` if you use the downloaded data directory:
```yaml
bpe_tokenizer:
bpe: sentencepiece
sentencepiece_model: ABS_PATH_TO_SENTENCEPIECE_MODEL
global_cmvn:
stats_npz_path: ABS_PATH_TO_GCMVN_FILE
input_channels: 1
input_feat_per_channel: 80
sampling_alpha: 1.0
specaugment:
freq_mask_F: 27
freq_mask_N: 1
time_mask_N: 1
time_mask_T: 100
time_mask_p: 1.0
time_wrap_W: 0
transforms:
'*':
- global_cmvn
_train:
- global_cmvn
- specaugment
vocab_filename: spm_unigram10000_st.txt
```
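If you use the downloaded data directory, the two absolute paths can be filled in programmatically; here is a minimal sketch with PyYAML (the `databin` path below is a placeholder for wherever you unpacked the archive):
```python
from pathlib import Path
import yaml

databin = Path("/path/to/must_c_v1.0_en_de_databin")  # placeholder: your unpacked directory
cfg_path = databin / "config_st.yaml"

with open(cfg_path) as f:
    cfg = yaml.safe_load(f)
# Point the two absolute paths at the local copies listed above.
cfg["bpe_tokenizer"]["sentencepiece_model"] = str(databin / "spm_unigram10000_st.model")
cfg["global_cmvn"]["stats_npz_path"] = str(databin / "gcmvn.npz")
with open(cfg_path, "w") as f:
    yaml.dump(cfg, f)
```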
Note that once `--data-bin` is set, `--config` should be the base name of the config YAML file, not its full path.
Set `--model-path` to the model checkpoint.
A pretrained checkpoint can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/convtransformer_wait5_pre7), which is a wait-5 model with a pre-decision window of 280 ms.
The result of this model on `tst-COMMON` is:
```json
{
"Quality": {
"BLEU": 13.94974229366959
},
"Latency": {
"AL": 1751.8031870037803,
"AL_CA": 2338.5911762796536,
"AP": 0.7931395378788959,
"AP_CA": 0.9405103863210942,
"DAL": 1987.7811616943081,
"DAL_CA": 2425.2751560926167
}
}
```
If the `--output ${OUTPUT}` option is used, the detailed logs and scores will be stored under the `${OUTPUT}` directory.
Quality is measured by detokenized BLEU, so make sure that the predicted words sent to the server are detokenized.
The latency metrics are
* Average Proportion
* Average Lagging
* Differentiable Average Lagging
These metrics are also computed on detokenized text.
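For reference, the sketch below shows the text-style definitions of Average Proportion and Average Lagging (Differentiable Average Lagging is omitted). SimulEval implements the speech variants, where the delays and the source length are measured in milliseconds, plus the computation-aware versions reported above as `*_CA`:
```python
def average_proportion(delays, src_len):
    # AP = 1 / (|X| * |Y|) * sum_t d(t), where d(t) is the source consumed before target t.
    return sum(delays) / (src_len * len(delays))

def average_lagging(delays, src_len):
    # AL = 1/tau * sum_{t<=tau} (d(t) - (t - 1) * |X| / |Y|),
    # with tau the first target index whose delay covers the whole source.
    per_token = src_len / len(delays)
    tau = next((t for t, d in enumerate(delays, start=1) if d >= src_len), len(delays))
    return sum(delays[t - 1] - (t - 1) * per_token for t in range(1, tau + 1)) / tau

delays_ms = [280, 560, 840, 1120, 1500]  # toy example: source read (ms) before each token
print(average_proportion(delays_ms, 1500), average_lagging(delays_ms, 1500))
```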
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
from pathlib import Path
import shutil
from tempfile import NamedTemporaryFile
from typing import Optional, Tuple
import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
load_df_from_tsv,
save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive
from tqdm import tqdm
log = logging.getLogger(__name__)
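# Manifest rows reference features inside fbank80.zip as "<zip>:<offset>:<size>";
# "n_frames" is estimated from the clip duration assuming the usual 25 ms window and
# 10 ms shift of the 80-dim log mel filter bank features: 1 + (duration_ms - 25) / 10.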
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
class CoVoST(Dataset):
"""Create a Dataset for CoVoST (https://github.com/facebookresearch/covost).
Args:
root (str): root path to the dataset and generated manifests/features
source_language (str): source (audio) language
target_language (str, optional): target (text) language,
None for no translation (default: None)
version (int, optional): CoVoST version. (default: 2)
download (bool, optional): Whether to download the dataset if it is not
found at root path. (default: ``False``).
"""
COVOST_URL_TEMPLATE = (
"https://dl.fbaipublicfiles.com/covost/"
"covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
)
VERSIONS = {2}
SPLITS = ["train", "dev", "test"]
XX_EN_LANGUAGES = {
1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
2: [
"fr",
"de",
"es",
"ca",
"it",
"ru",
"zh-CN",
"pt",
"fa",
"et",
"mn",
"nl",
"tr",
"ar",
"sv-SE",
"lv",
"sl",
"ta",
"ja",
"id",
"cy",
],
}
EN_XX_LANGUAGES = {
1: [],
2: [
"de",
"tr",
"fa",
"sv-SE",
"mn",
"zh-CN",
"cy",
"ca",
"sl",
"et",
"id",
"ar",
"ta",
"lv",
"ja",
],
}
def __init__(
self,
root: str,
split: str,
source_language: str,
target_language: Optional[str] = None,
version: int = 2,
) -> None:
assert version in self.VERSIONS and split in self.SPLITS
assert source_language is not None
self.no_translation = target_language is None
if not self.no_translation:
assert "en" in {source_language, target_language}
if source_language == "en":
assert target_language in self.EN_XX_LANGUAGES[version]
else:
assert source_language in self.XX_EN_LANGUAGES[version]
else:
# Hack here so that we can get "split" column from CoVoST TSV.
# Note that we use CoVoST train split for ASR which is an extension
# to Common Voice train split.
target_language = "de" if source_language == "en" else "en"
self.root: Path = Path(root)
cv_tsv_path = self.root / "validated.tsv"
assert cv_tsv_path.is_file()
covost_url = self.COVOST_URL_TEMPLATE.format(
src_lang=source_language, tgt_lang=target_language
)
covost_archive = self.root / Path(covost_url).name
if not covost_archive.is_file():
download_url(covost_url, self.root.as_posix(), hash_value=None)
extract_archive(covost_archive.as_posix())
cv_tsv = load_df_from_tsv(cv_tsv_path)
covost_tsv = load_df_from_tsv(
self.root / Path(covost_url).name.replace(".tar.gz", "")
)
df = pd.merge(
left=cv_tsv[["path", "sentence", "client_id"]],
right=covost_tsv[["path", "translation", "split"]],
how="inner",
on="path",
)
if split == "train":
df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
else:
df = df[df["split"] == split]
data = df.to_dict(orient="index").items()
data = [v for k, v in sorted(data, key=lambda x: x[0])]
self.data = []
for e in data:
try:
path = self.root / "clips" / e["path"]
_ = torchaudio.info(path.as_posix())
self.data.append(e)
except RuntimeError:
pass
def __getitem__(
self, n: int
) -> Tuple[Tensor, int, str, str, Optional[str], str, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
tuple: ``(waveform, sample_rate, sentence, translation, speaker_id,
sample_id)``
"""
data = self.data[n]
path = self.root / "clips" / data["path"]
waveform, sample_rate = torchaudio.load(path)
sentence = data["sentence"]
translation = None if self.no_translation else data["translation"]
speaker_id = data["client_id"]
_id = data["path"].replace(".mp3", "")
return waveform, sample_rate, sentence, translation, speaker_id, _id
def __len__(self) -> int:
return len(self.data)
def process(args):
root = Path(args.data_root).absolute() / args.src_lang
if not root.is_dir():
raise NotADirectoryError(f"{root} does not exist")
# Extract features
feature_root = root / "fbank80"
feature_root.mkdir(exist_ok=True)
for split in CoVoST.SPLITS:
print(f"Fetching split {split}...")
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
print("Extracting log mel filter bank features...")
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
extract_fbank_features(
waveform, sample_rate, feature_root / f"{utt_id}.npy"
)
# Pack features into ZIP
zip_path = root / "fbank80.zip"
print("ZIPing features...")
create_zip(feature_root, zip_path)
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(zip_path)
# Generate TSV manifest
print("Generating manifest...")
train_text = []
task = f"asr_{args.src_lang}"
if args.tgt_lang is not None:
task = f"st_{args.src_lang}_{args.tgt_lang}"
for split in CoVoST.SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest["id"].append(utt_id)
manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
manifest["speaker"].append(speaker_id)
is_train_split = split.startswith("train")
if is_train_split:
train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, root / f"{split}_{task}.tsv")
# Generate vocab
vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + "\n")
gen_vocab(
Path(f.name),
root / spm_filename_prefix,
args.vocab_type,
args.vocab_size
)
# Generate config YAML
gen_config_yaml(
root,
spm_filename_prefix + ".model",
yaml_filename=f"config_{task}.yaml",
specaugment_policy="lb",
)
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--data-root", "-d", required=True, type=str,
help="data root with sub-folders for each language <root>/<src_lang>"
)
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
),
parser.add_argument("--vocab-size", default=1000, type=int)
parser.add_argument("--src-lang", "-s", required=True, type=str)
parser.add_argument("--tgt-lang", "-t", type=str)
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
from pathlib import Path
import shutil
from tempfile import NamedTemporaryFile
import pandas as pd
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torchaudio.datasets import LIBRISPEECH
from tqdm import tqdm
log = logging.getLogger(__name__)
SPLITS = [
"train-clean-100",
"train-clean-360",
"train-other-500",
"dev-clean",
"dev-other",
"test-clean",
"test-other",
]
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
def process(args):
out_root = Path(args.output_root).absolute()
out_root.mkdir(exist_ok=True)
# Extract features
feature_root = out_root / "fbank80"
feature_root.mkdir(exist_ok=True)
for split in SPLITS:
print(f"Fetching split {split}...")
dataset = LIBRISPEECH(out_root.as_posix(), url=split, download=True)
print("Extracting log mel filter bank features...")
for wav, sample_rate, _, spk_id, chapter_no, utt_no in tqdm(dataset):
sample_id = f"{spk_id}-{chapter_no}-{utt_no}"
extract_fbank_features(
wav, sample_rate, feature_root / f"{sample_id}.npy"
)
# Pack features into ZIP
zip_path = out_root / "fbank80.zip"
print("ZIPing features...")
create_zip(feature_root, zip_path)
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(zip_path)
# Generate TSV manifest
print("Generating manifest...")
train_text = []
for split in SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = LIBRISPEECH(out_root.as_posix(), url=split)
for wav, sample_rate, utt, spk_id, chapter_no, utt_no in tqdm(dataset):
sample_id = f"{spk_id}-{chapter_no}-{utt_no}"
manifest["id"].append(sample_id)
manifest["audio"].append(zip_manifest[sample_id])
duration_ms = int(wav.size(1) / sample_rate * 1000)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
manifest["tgt_text"].append(utt.lower())
manifest["speaker"].append(spk_id)
save_df_to_tsv(
pd.DataFrame.from_dict(manifest), out_root / f"{split}.tsv"
)
if split.startswith("train"):
train_text.extend(manifest["tgt_text"])
# Generate vocab
vocab_size = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size}"
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + "\n")
gen_vocab(
Path(f.name),
out_root / spm_filename_prefix,
args.vocab_type,
args.vocab_size,
)
# Generate config YAML
gen_config_yaml(
out_root, spm_filename_prefix + ".model", specaugment_policy="ld"
)
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--output-root", "-o", required=True, type=str)
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
),
parser.add_argument("--vocab-size", default=10000, type=int)
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import os
from pathlib import Path
import shutil
from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple
import pandas as pd
import soundfile as sf
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
load_df_from_tsv,
save_df_to_tsv,
)
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from fairseq.data.audio.audio_utils import get_waveform
log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker", "tgt_lang"]
class mTEDx(Dataset):
"""
Create a Dataset for Multilingual TEDx.
Each item is a tuple of the form: waveform, sample_rate, source utterance,
target utterance, speaker_id, utterance_id
"""
SPLITS = ["train", "valid", "test"]
LANGPAIRS = ["es-es", "fr-fr", "pt-pt", "it-it", "ru-ru", "el-el", "ar-ar", "de-de",
"es-en", "es-fr", "es-pt", "es-it", "fr-en", "fr-es", "fr-pt",
"pt-en", "pt-es", "it-en", "it-es", "ru-en", "el-en"]
def __init__(self, root: str, lang: str, split: str) -> None:
assert split in self.SPLITS and lang in self.LANGPAIRS
_root = Path(root) / f"{lang}" / "data" / split
wav_root, txt_root = _root / "wav", _root / "txt"
assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir()
# Load audio segments
try:
import yaml
except ImportError:
print("Please install PyYAML to load the Multilingual TEDx YAML files")
with open(txt_root / f"{split}.yaml") as f:
segments = yaml.load(f, Loader=yaml.BaseLoader)
# Load source and target utterances
src, tgt = lang.split("-")
for _lang in [src, tgt]:
with open(txt_root / f"{split}.{_lang}") as f:
utterances = [r.strip() for r in f]
assert len(segments) == len(utterances)
for i, u in enumerate(utterances):
segments[i][_lang] = u
# Gather info
self.data = []
for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
wav_filename = wav_filename.replace(".wav", ".flac")
wav_path = wav_root / wav_filename
sample_rate = sf.info(wav_path.as_posix()).samplerate
seg_group = sorted(_seg_group, key=lambda x: float(x["offset"]))
for i, segment in enumerate(seg_group):
offset = int(float(segment["offset"]) * sample_rate)
n_frames = int(float(segment["duration"]) * sample_rate)
_id = f"{wav_path.stem}_{i}"
self.data.append(
(
wav_path.as_posix(),
offset,
n_frames,
sample_rate,
segment[src],
segment[tgt],
segment["speaker_id"],
tgt,
_id,
)
)
def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str, str]:
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id = self.data[n]
waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset)
waveform = torch.from_numpy(waveform)
return waveform, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id
def __len__(self) -> int:
return len(self.data)
def process(args):
root = Path(args.data_root).absolute()
for lang in mTEDx.LANGPAIRS:
cur_root = root / f"{lang}"
if not cur_root.is_dir():
print(f"{cur_root.as_posix()} does not exist. Skipped.")
continue
# Extract features
feature_root = cur_root / "fbank80"
feature_root.mkdir(exist_ok=True)
for split in mTEDx.SPLITS:
print(f"Fetching split {split}...")
dataset = mTEDx(root.as_posix(), lang, split)
print("Extracting log mel filter bank features...")
for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset):
extract_fbank_features(
waveform, sample_rate, feature_root / f"{utt_id}.npy"
)
# Pack features into ZIP
zip_path = cur_root / "fbank80.zip"
print("ZIPing features...")
create_zip(feature_root, zip_path)
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(zip_path)
# Generate TSV manifest
print("Generating manifest...")
train_text = []
for split in mTEDx.SPLITS:
is_train_split = split.startswith("train")
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = mTEDx(args.data_root, lang, split)
for wav, sr, src_utt, tgt_utt, speaker_id, tgt_lang, utt_id in tqdm(dataset):
manifest["id"].append(utt_id)
manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000)
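                # 1 + (duration_ms - 25) / 10 approximates the number of fbank frames
                # for a 25 ms analysis window with a 10 ms shift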
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
manifest["speaker"].append(speaker_id)
manifest["tgt_lang"].append(tgt_lang)
if is_train_split:
train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv")
# Generate vocab
v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + "\n")
gen_vocab(
Path(f.name),
cur_root / spm_filename_prefix,
args.vocab_type,
args.vocab_size,
)
# Generate config YAML
gen_config_yaml(
cur_root,
spm_filename_prefix + ".model",
yaml_filename=f"config_{args.task}.yaml",
specaugment_policy="lb",
)
# Clean up
shutil.rmtree(feature_root)
def process_joint(args):
cur_root = Path(args.data_root)
assert all((cur_root / f"{lang}").is_dir() for lang in mTEDx.LANGPAIRS), \
"do not have downloaded data available for all languages"
# Generate vocab
vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}"
with NamedTemporaryFile(mode="w") as f:
for lang in mTEDx.LANGPAIRS:
tsv_path = cur_root / f"{lang}" / f"train_{args.task}.tsv"
df = load_df_from_tsv(tsv_path)
for t in df["tgt_text"]:
f.write(t + "\n")
special_symbols = None
if args.joint:
# Add tgt_lang tags to dict
special_symbols = list({f'<lang:{lang.split("-")[1]}>' for lang in mTEDx.LANGPAIRS})
gen_vocab(
Path(f.name),
cur_root / spm_filename_prefix,
args.vocab_type,
args.vocab_size,
special_symbols=special_symbols
)
# Generate config YAML
gen_config_yaml(
cur_root,
spm_filename_prefix + ".model",
yaml_filename=f"config_{args.task}.yaml",
specaugment_policy="ld",
prepend_tgt_lang_tag=(args.joint),
)
# Make symbolic links to manifests
for lang in mTEDx.LANGPAIRS:
for split in mTEDx.SPLITS:
src_path = cur_root / f"{lang}" / f"{split}_{args.task}.tsv"
desc_path = cur_root / f"{split}_{lang}_{args.task}.tsv"
if not desc_path.is_symlink():
os.symlink(src_path, desc_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--data-root", "-d", required=True, type=str)
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
    )
parser.add_argument("--vocab-size", default=8000, type=int)
parser.add_argument("--task", type=str, choices=["asr", "st"])
parser.add_argument("--joint", action="store_true", help="")
args = parser.parse_args()
if args.joint:
process_joint(args)
else:
process(args)
if __name__ == "__main__":
main()
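# Hedged example invocation (the script filename and data path are assumptions,
# not confirmed by this file):
#
#   python prep_mtedx_data.py \
#       --data-root /path/to/mtedx --task asr \
#       --vocab-type unigram --vocab-size 8000
#
# Adding --joint runs process_joint() instead, which builds a shared vocabulary
# and config and symlinks the per-pair manifests (these must already exist).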
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import os
from pathlib import Path
import shutil
from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple
import numpy as np
import pandas as pd
import soundfile as sf
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
load_df_from_tsv,
save_df_to_tsv,
cal_gcmvn_stats,
)
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from fairseq.data.audio.audio_utils import get_waveform
log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
class MUSTC(Dataset):
"""
Create a Dataset for MuST-C. Each item is a tuple of the form:
waveform, sample_rate, source utterance, target utterance, speaker_id,
utterance_id
"""
SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
def __init__(self, root: str, lang: str, split: str) -> None:
assert split in self.SPLITS and lang in self.LANGUAGES
_root = Path(root) / f"en-{lang}" / "data" / split
wav_root, txt_root = _root / "wav", _root / "txt"
assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir()
# Load audio segments
        try:
            import yaml
        except ImportError:
            print("Please install PyYAML to load the MuST-C YAML files")
            raise
with open(txt_root / f"{split}.yaml") as f:
segments = yaml.load(f, Loader=yaml.BaseLoader)
# Load source and target utterances
for _lang in ["en", lang]:
with open(txt_root / f"{split}.{_lang}") as f:
utterances = [r.strip() for r in f]
assert len(segments) == len(utterances)
for i, u in enumerate(utterances):
segments[i][_lang] = u
# Gather info
self.data = []
for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
wav_path = wav_root / wav_filename
sample_rate = sf.info(wav_path.as_posix()).samplerate
seg_group = sorted(_seg_group, key=lambda x: x["offset"])
for i, segment in enumerate(seg_group):
offset = int(float(segment["offset"]) * sample_rate)
n_frames = int(float(segment["duration"]) * sample_rate)
_id = f"{wav_path.stem}_{i}"
self.data.append(
(
wav_path.as_posix(),
offset,
n_frames,
sample_rate,
segment["en"],
segment[lang],
segment["speaker_id"],
_id,
)
)
def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str]:
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset)
waveform = torch.from_numpy(waveform)
return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
def __len__(self) -> int:
return len(self.data)
def process(args):
root = Path(args.data_root).absolute()
for lang in MUSTC.LANGUAGES:
cur_root = root / f"en-{lang}"
if not cur_root.is_dir():
print(f"{cur_root.as_posix()} does not exist. Skipped.")
continue
# Extract features
feature_root = cur_root / "fbank80"
feature_root.mkdir(exist_ok=True)
for split in MUSTC.SPLITS:
print(f"Fetching split {split}...")
dataset = MUSTC(root.as_posix(), lang, split)
print("Extracting log mel filter bank features...")
if split == 'train' and args.cmvn_type == "global":
print("And estimating cepstral mean and variance stats...")
gcmvn_feature_list = []
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
features = extract_fbank_features(waveform, sample_rate)
np.save(
(feature_root / f"{utt_id}.npy").as_posix(),
features
)
if split == 'train' and args.cmvn_type == "global":
if len(gcmvn_feature_list) < args.gcmvn_max_num:
gcmvn_feature_list.append(features)
if split == 'train' and args.cmvn_type == "global":
            # Estimate and save global cmvn stats
stats = cal_gcmvn_stats(gcmvn_feature_list)
with open(cur_root / "gcmvn.npz", "wb") as f:
np.savez(f, mean=stats["mean"], std=stats["std"])
# Pack features into ZIP
zip_path = cur_root / "fbank80.zip"
print("ZIPing features...")
create_zip(feature_root, zip_path)
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(zip_path)
# Generate TSV manifest
print("Generating manifest...")
train_text = []
for split in MUSTC.SPLITS:
is_train_split = split.startswith("train")
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = MUSTC(args.data_root, lang, split)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest["id"].append(utt_id)
manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000)
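                # 1 + (duration_ms - 25) / 10 approximates the number of fbank frames
                # for a 25 ms analysis window with a 10 ms shift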
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
manifest["speaker"].append(speaker_id)
if is_train_split:
train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv")
# Generate vocab
v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
f.write(t + "\n")
gen_vocab(
Path(f.name),
cur_root / spm_filename_prefix,
args.vocab_type,
args.vocab_size,
)
# Generate config YAML
gen_config_yaml(
cur_root,
spm_filename_prefix + ".model",
yaml_filename=f"config_{args.task}.yaml",
specaugment_policy="lb",
cmvn_type=args.cmvn_type,
gcmvn_path=(
cur_root / "gcmvn.npz" if args.cmvn_type == "global"
else None
),
)
# Clean up
shutil.rmtree(feature_root)
def process_joint(args):
cur_root = Path(args.data_root)
assert all((cur_root / f"en-{lang}").is_dir() for lang in MUSTC.LANGUAGES), \
"do not have downloaded data available for all 8 languages"
# Generate vocab
vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}"
with NamedTemporaryFile(mode="w") as f:
for lang in MUSTC.LANGUAGES:
tsv_path = cur_root / f"en-{lang}" / f"train_{args.task}.tsv"
df = load_df_from_tsv(tsv_path)
for t in df["tgt_text"]:
f.write(t + "\n")
special_symbols = None
if args.task == 'st':
special_symbols = [f'<lang:{lang}>' for lang in MUSTC.LANGUAGES]
gen_vocab(
Path(f.name),
cur_root / spm_filename_prefix,
args.vocab_type,
args.vocab_size,
special_symbols=special_symbols
)
# Generate config YAML
gen_config_yaml(
cur_root,
spm_filename_prefix + ".model",
yaml_filename=f"config_{args.task}.yaml",
specaugment_policy="ld",
prepend_tgt_lang_tag=(args.task == "st"),
)
# Make symbolic links to manifests
for lang in MUSTC.LANGUAGES:
for split in MUSTC.SPLITS:
src_path = cur_root / f"en-{lang}" / f"{split}_{args.task}.tsv"
desc_path = cur_root / f"{split}_{lang}_{args.task}.tsv"
if not desc_path.is_symlink():
os.symlink(src_path, desc_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--data-root", "-d", required=True, type=str)
parser.add_argument(
"--vocab-type",
default="unigram",
required=True,
type=str,
choices=["bpe", "unigram", "char"],
    )
parser.add_argument("--vocab-size", default=8000, type=int)
parser.add_argument("--task", type=str, choices=["asr", "st"])
parser.add_argument("--joint", action="store_true", help="")
parser.add_argument("--cmvn-type", default="utterance",
choices=["global", "utterance"],
help="The type of cepstral mean and variance normalization")
parser.add_argument("--gcmvn-max-num", default=150000, type=int,
help=(
"Maximum number of sentences to use to estimate"
"global mean and variance"
))
args = parser.parse_args()
if args.joint:
process_joint(args)
else:
process(args)
if __name__ == "__main__":
main()
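# Hedged example invocation (paths are placeholders; the module path matches the
# import used by the segmentation script below):
#
#   python examples/speech_to_text/prep_mustc_data.py \
#       --data-root /path/to/mustc --task st \
#       --vocab-type unigram --vocab-size 8000 \
#       --cmvn-type global
#
# With --cmvn-type global, per-language gcmvn.npz stats are estimated from up to
# --gcmvn-max-num training utterances and referenced from the generated config.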
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
from pathlib import Path
import soundfile as sf
from examples.speech_to_text.prep_mustc_data import (
MUSTC
)
from tqdm import tqdm
log = logging.getLogger(__name__)
def main(args):
root = Path(args.data_root).absolute()
lang = args.lang
split = args.split
cur_root = root / f"en-{lang}"
assert cur_root.is_dir(), (
f"{cur_root.as_posix()} does not exist. Skipped."
)
dataset = MUSTC(root.as_posix(), lang, split)
output = Path(args.output).absolute()
output.mkdir(exist_ok=True)
    # Use context managers so the transcript and wav-list files are closed properly
    with open(output / f"{split}.{lang}", "w") as f_text, \
            open(output / f"{split}.wav_list", "w") as f_wav_list:
        for waveform, sample_rate, _, text, _, utt_id in tqdm(dataset):
            sf.write(
                output / f"{utt_id}.wav",
                waveform.squeeze(0).numpy(),
                samplerate=int(sample_rate)
            )
            f_text.write(text + "\n")
            f_wav_list.write(str(output / f"{utt_id}.wav") + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-root", "-d", required=True, type=str)
parser.add_argument("--task", required=True, type=str, choices=["asr", "st"])
parser.add_argument("--lang", required=True, type=str)
parser.add_argument("--output", required=True, type=str)
parser.add_argument("--split", required=True, choices=MUSTC.SPLITS)
args = parser.parse_args()
main(args)
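# Hedged example invocation (the script filename and paths are assumptions): this
# writes one wav per segment plus {split}.{lang} transcripts and {split}.wav_list.
#
#   python seg_mustc_data.py \
#       --data-root /path/to/mustc --lang de --split tst-COMMON \
#       --task st --output /path/to/segments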
import math
import os
import json
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
import yaml
from fairseq import checkpoint_utils, tasks
from fairseq.file_io import PathManager
try:
from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS
from simuleval.agents import SpeechAgent
from simuleval.states import ListEntry, SpeechStates
except ImportError:
print("Please install simuleval 'pip install simuleval'")
SHIFT_SIZE = 10
WINDOW_SIZE = 25
SAMPLE_RATE = 16000
FEATURE_DIM = 80
BOW_PREFIX = "\u2581"
class OnlineFeatureExtractor:
"""
Extract speech feature on the fly.
"""
def __init__(self, args):
self.shift_size = args.shift_size
self.window_size = args.window_size
assert self.window_size >= self.shift_size
self.sample_rate = args.sample_rate
self.feature_dim = args.feature_dim
self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000)
self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000)
self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000
self.previous_residual_samples = []
self.global_cmvn = args.global_cmvn
def clear_cache(self):
self.previous_residual_samples = []
def __call__(self, new_samples):
samples = self.previous_residual_samples + new_samples
if len(samples) < self.num_samples_per_window:
self.previous_residual_samples = samples
return
# num_frames is the number of frames from the new segment
num_frames = math.floor(
(len(samples) - self.len_ms_to_samples(self.window_size - self.shift_size))
/ self.num_samples_per_shift
)
# the number of frames used for feature extraction
        # including some part of the previous segment
effective_num_samples = int(
num_frames * self.len_ms_to_samples(self.shift_size)
+ self.len_ms_to_samples(self.window_size - self.shift_size)
)
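        # Worked example (illustrative numbers): at 16 kHz with a 25 ms window and
        # 10 ms shift, 1000 buffered samples give
        #   num_frames = floor((1000 - 240) / 160) = 4
        #   effective_num_samples = 4 * 160 + 240 = 880
        # and samples[640:] (the last 360 samples) are kept for the next call.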
input_samples = samples[:effective_num_samples]
self.previous_residual_samples = samples[
num_frames * self.num_samples_per_shift:
]
torch.manual_seed(1)
output = kaldi.fbank(
torch.FloatTensor(input_samples).unsqueeze(0),
num_mel_bins=self.feature_dim,
frame_length=self.window_size,
frame_shift=self.shift_size,
).numpy()
output = self.transform(output)
return torch.from_numpy(output)
def transform(self, input):
if self.global_cmvn is None:
return input
mean = self.global_cmvn["mean"]
std = self.global_cmvn["std"]
x = np.subtract(input, mean)
x = np.divide(x, std)
return x
class TensorListEntry(ListEntry):
"""
Data structure to store a list of tensor.
"""
def append(self, value):
if len(self.value) == 0:
self.value = value
return
self.value = torch.cat([self.value] + [value], dim=0)
def info(self):
return {
"type": str(self.new_value_type),
"length": self.__len__(),
"value": "" if type(self.value) is list else self.value.size(),
}
class FairseqSimulSTAgent(SpeechAgent):
speech_segment_size = 40 # in ms, 4 pooling ratio * 10 ms step size
def __init__(self, args):
super().__init__(args)
self.eos = DEFAULT_EOS
self.gpu = getattr(args, "gpu", False)
self.args = args
self.load_model_vocab(args)
if getattr(
self.model.decoder.layers[0].encoder_attn,
'pre_decision_ratio',
None
) is not None:
self.speech_segment_size *= (
self.model.decoder.layers[0].encoder_attn.pre_decision_ratio
)
args.global_cmvn = None
if args.config:
with open(os.path.join(args.data_bin, args.config), "r") as f:
config = yaml.load(f, Loader=yaml.BaseLoader)
if "global_cmvn" in config:
args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"])
if args.global_stats:
with PathManager.open(args.global_stats, "r") as f:
global_cmvn = json.loads(f.read())
self.global_cmvn = {"mean": global_cmvn["mean"], "std": global_cmvn["stddev"]}
self.feature_extractor = OnlineFeatureExtractor(args)
self.max_len = args.max_len
self.force_finish = args.force_finish
torch.set_grad_enabled(False)
def build_states(self, args, client, sentence_id):
        # Initialize states here, e.g. by adding customized entries to states.
        # This function is called at the beginning of every new sentence.
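        # For instance, initialize_states() below attaches a TensorListEntry for
        # the source speech features and a plain ListEntry for target tokens.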
states = SpeechStates(args, client, sentence_id, self)
self.initialize_states(states)
return states
def to_device(self, tensor):
if self.gpu:
return tensor.cuda()
else:
return tensor.cpu()
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--model-path', type=str, required=True,
help='path to your pretrained model.')
parser.add_argument("--data-bin", type=str, required=True,
help="Path of data binary")
parser.add_argument("--config", type=str, default=None,
help="Path to config yaml file")
parser.add_argument("--global-stats", type=str, default=None,
help="Path to json file containing cmvn stats")
parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece",
help="Subword splitter type for target text")
parser.add_argument("--tgt-splitter-path", type=str, default=None,
help="Subword splitter model path for target text")
parser.add_argument("--user-dir", type=str, default="examples/simultaneous_translation",
help="User directory for simultaneous translation")
parser.add_argument("--max-len", type=int, default=200,
help="Max length of translation")
parser.add_argument("--force-finish", default=False, action="store_true",
help="Force the model to finish the hypothsis if the source is not finished")
parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE,
help="Shift size of feature extraction window.")
parser.add_argument("--window-size", type=int, default=WINDOW_SIZE,
help="Window size of feature extraction window.")
parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE,
help="Sample rate")
parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM,
help="Acoustic feature dimension.")
# fmt: on
return parser
def load_model_vocab(self, args):
filename = args.model_path
if not os.path.exists(filename):
raise IOError("Model file not found: {}".format(filename))
state = checkpoint_utils.load_checkpoint_to_cpu(filename)
task_args = state["cfg"]["task"]
task_args.data = args.data_bin
if args.config is not None:
task_args.config_yaml = args.config
task = tasks.setup_task(task_args)
# build model for ensemble
state["cfg"]["model"].load_pretrained_encoder_from = None
state["cfg"]["model"].load_pretrained_decoder_from = None
self.model = task.build_model(state["cfg"]["model"])
self.model.load_state_dict(state["model"], strict=True)
self.model.eval()
self.model.share_memory()
if self.gpu:
self.model.cuda()
# Set dictionary
self.dict = {}
self.dict["tgt"] = task.target_dictionary
def initialize_states(self, states):
self.feature_extractor.clear_cache()
states.units.source = TensorListEntry()
states.units.target = ListEntry()
states.incremental_states = dict()
def segment_to_units(self, segment, states):
# Convert speech samples to features
features = self.feature_extractor(segment)
if features is not None:
return [features]
else:
return []
def units_to_segment(self, units, states):
        # Merge subword units into full words.
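        # Worked example (illustrative): BOW_PREFIX ("\u2581") marks the start of a
        # SentencePiece word, so units decoding to ["\u2581an", "other", "\u2581word"]
        # emit "another" as soon as the second word boundary is seen, while the
        # trailing unit stays buffered until the next boundary or EOS.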
if self.model.decoder.dictionary.eos() == units[0]:
return DEFAULT_EOS
segment = []
if None in units.value:
units.value.remove(None)
for index in units:
if index is None:
units.pop()
token = self.model.decoder.dictionary.string([index])
if token.startswith(BOW_PREFIX):
if len(segment) == 0:
segment += [token.replace(BOW_PREFIX, "")]
else:
for j in range(len(segment)):
units.pop()
string_to_return = ["".join(segment)]
if self.model.decoder.dictionary.eos() == units[0]:
string_to_return += [DEFAULT_EOS]
return string_to_return
else:
segment += [token.replace(BOW_PREFIX, "")]
if (
len(units) > 0
and self.model.decoder.dictionary.eos() == units[-1]
or len(states.units.target) > self.max_len
):
tokens = [self.model.decoder.dictionary.string([unit]) for unit in units]
return ["".join(tokens).replace(BOW_PREFIX, "")] + [DEFAULT_EOS]
return None
def update_model_encoder(self, states):
if len(states.units.source) == 0:
return
src_indices = self.to_device(
states.units.source.value.unsqueeze(0)
)
src_lengths = self.to_device(
torch.LongTensor([states.units.source.value.size(0)])
)
states.encoder_states = self.model.encoder(src_indices, src_lengths)
torch.cuda.empty_cache()
def update_states_read(self, states):
# Happens after a read action.
self.update_model_encoder(states)
def policy(self, states):
if not getattr(states, "encoder_states", None):
return READ_ACTION
tgt_indices = self.to_device(
torch.LongTensor(
[self.model.decoder.dictionary.eos()]
+ [x for x in states.units.target.value if x is not None]
).unsqueeze(0)
)
states.incremental_states["steps"] = {
"src": states.encoder_states["encoder_out"][0].size(0),
"tgt": 1 + len(states.units.target),
}
states.incremental_states["online"] = {"only": torch.tensor(not states.finish_read())}
x, outputs = self.model.decoder.forward(
prev_output_tokens=tgt_indices,
encoder_out=states.encoder_states,
incremental_state=states.incremental_states,
)
states.decoder_out = x
states.decoder_out_extra = outputs
torch.cuda.empty_cache()
if outputs.action == 0:
return READ_ACTION
else:
return WRITE_ACTION
def predict(self, states):
decoder_states = states.decoder_out
lprobs = self.model.get_normalized_probs(
[decoder_states[:, -1:]], log_probs=True
)
index = lprobs.argmax(dim=-1)
index = index[0, 0].item()
if (
self.force_finish
and index == self.model.decoder.dictionary.eos()
and not states.finish_read()
):
            # If we want to force-finish the translation
            # (i.e. not emit EOS before the source is fully read), return None
# self.model.decoder.clear_cache(states.incremental_states)
index = None
return index