"tests/fixtures/custom_pipeline/pipeline.py" did not exist on "76d492ea49342b486dfbca1dbcdfbb052fe34112"
Commit 799a38c5 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #616 failed with stages
in 0 seconds
import importlib
import os
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module("examples.speech_recognition.models." + model_name)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import math
from collections.abc import Iterable
import torch
import torch.nn as nn
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqEncoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.modules import (
LinearizedConvolution,
TransformerDecoderLayer,
TransformerEncoderLayer,
VGGBlock,
)
@register_model("asr_vggtransformer")
class VGGTransformerModel(FairseqEncoderDecoderModel):
"""
Transformers with convolutional context for ASR
https://arxiv.org/abs/1904.11660
"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--input-feat-per-channel",
type=int,
metavar="N",
help="encoder input dimension per input channel",
)
parser.add_argument(
"--vggblock-enc-config",
type=str,
metavar="EXPR",
help="""
an array of tuples each containing the configuration of one vggblock:
[(out_channels,
conv_kernel_size,
pooling_kernel_size,
num_conv_layers,
use_layer_norm), ...]
""",
)
parser.add_argument(
"--transformer-enc-config",
type=str,
metavar="EXPR",
help=""""
a tuple containing the configuration of the encoder transformer layers
configurations:
[(input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout), ...]
""",
)
parser.add_argument(
"--enc-output-dim",
type=int,
metavar="N",
help="""
encoder output dimension, can be None. If specified, the transformer
output is projected to this dimension""",
)
parser.add_argument(
"--in-channels",
type=int,
metavar="N",
help="number of encoder input channels",
)
parser.add_argument(
"--tgt-embed-dim",
type=int,
metavar="N",
help="embedding dimension of the decoder target tokens",
)
parser.add_argument(
"--transformer-dec-config",
type=str,
metavar="EXPR",
help="""
a tuple containing the configuration of the decoder transformer layers
configurations:
[(input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout), ...]
""",
)
parser.add_argument(
"--conv-dec-config",
type=str,
metavar="EXPR",
help="""
an array of tuples for the decoder 1-D convolution config
[(out_channels, conv_kernel_size, use_layer_norm), ...]""",
)
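# Illustrative CLI values for the EXPR-type options above (they are parsed with
# eval() in build_encoder/build_decoder below); these mirror the defaults in
# this file rather than a tuned recipe:
#   --vggblock-enc-config "[(32, 3, 2, 2, False)] * 2"
#   --transformer-enc-config "((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2"
#   --transformer-dec-config "((256, 2, 1024, True, 0.2, 0.2, 0.2),) * 2"
#   --conv-dec-config "((256, 3, True),) * 2"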
@classmethod
def build_encoder(cls, args, task):
return VGGTransformerEncoder(
input_feat_per_channel=args.input_feat_per_channel,
vggblock_config=eval(args.vggblock_enc_config),
transformer_config=eval(args.transformer_enc_config),
encoder_output_dim=args.enc_output_dim,
in_channels=args.in_channels,
)
@classmethod
def build_decoder(cls, args, task):
return TransformerDecoder(
dictionary=task.target_dictionary,
embed_dim=args.tgt_embed_dim,
transformer_config=eval(args.transformer_dec_config),
conv_config=eval(args.conv_dec_config),
encoder_output_dim=args.enc_output_dim,
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure that all args are properly defaulted
# (in case there are any new ones)
base_architecture(args)
encoder = cls.build_encoder(args, task)
decoder = cls.build_decoder(args, task)
return cls(encoder, decoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
# net_output['encoder_out'] is a (B, T, D) tensor
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
lprobs.batch_first = True
return lprobs
DEFAULT_ENC_VGGBLOCK_CONFIG = ((32, 3, 2, 2, False),) * 2
DEFAULT_ENC_TRANSFORMER_CONFIG = ((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2
# 256: embedding dimension
# 4: number of heads
# 1024: FFN
# True: apply layerNorm before (dropout + residual) instead of after
# 0.2 (dropout): dropout after MultiheadAttention and second FC
# 0.2 (attention_dropout): dropout in MultiheadAttention
# 0.2 (relu_dropout): dropout after ReLU
DEFAULT_DEC_TRANSFORMER_CONFIG = ((256, 2, 1024, True, 0.2, 0.2, 0.2),) * 2
DEFAULT_DEC_CONV_CONFIG = ((256, 3, True),) * 2
# TODO: replace transformer encoder config from a one-liner
# to explicit args to get rid of this transformation
def prepare_transformer_encoder_params(
input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout,
):
args = argparse.Namespace()
args.encoder_embed_dim = input_dim
args.encoder_attention_heads = num_heads
args.attention_dropout = attention_dropout
args.dropout = dropout
args.activation_dropout = relu_dropout
args.encoder_normalize_before = normalize_before
args.encoder_ffn_embed_dim = ffn_dim
return args
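# Illustrative mapping (follows directly from the function above): the default
# tuple (256, 4, 1024, True, 0.2, 0.2, 0.2) becomes a Namespace with
# encoder_embed_dim=256, encoder_attention_heads=4, encoder_ffn_embed_dim=1024,
# encoder_normalize_before=True, dropout=0.2, attention_dropout=0.2 and
# activation_dropout=0.2, which is what TransformerEncoderLayer consumes.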
def prepare_transformer_decoder_params(
input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout,
):
args = argparse.Namespace()
args.encoder_embed_dim = None
args.decoder_embed_dim = input_dim
args.decoder_attention_heads = num_heads
args.attention_dropout = attention_dropout
args.dropout = dropout
args.activation_dropout = relu_dropout
args.decoder_normalize_before = normalize_before
args.decoder_ffn_embed_dim = ffn_dim
return args
class VGGTransformerEncoder(FairseqEncoder):
"""VGG + Transformer encoder"""
def __init__(
self,
input_feat_per_channel,
vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
encoder_output_dim=512,
in_channels=1,
transformer_context=None,
transformer_sampling=None,
):
"""constructor for VGGTransformerEncoder
Args:
- input_feat_per_channel: feature dim (not including stacked,
just base feature)
- in_channels: number of input channels (e.g., if 8 feature vectors are
stacked together, this is 8)
- vggblock_config: configuration of vggblock, see comments on
DEFAULT_ENC_VGGBLOCK_CONFIG
- transformer_config: configuration of transformer layer, see comments
on DEFAULT_ENC_TRANSFORMER_CONFIG
- encoder_output_dim: final transformer output embedding dimension
- transformer_context: (left, right) if set, self-attention will be focused
on (t-left, t+right)
- transformer_sampling: an iterable of int whose length must match
len(transformer_config); transformer_sampling[i] is the sampling factor
applied after the multihead attention and feedforward part of the i-th
transformer layer
"""
super().__init__(None)
self.num_vggblocks = 0
if vggblock_config is not None:
if not isinstance(vggblock_config, Iterable):
raise ValueError("vggblock_config is not iterable")
self.num_vggblocks = len(vggblock_config)
self.conv_layers = nn.ModuleList()
self.in_channels = in_channels
self.input_dim = input_feat_per_channel
self.pooling_kernel_sizes = []
if vggblock_config is not None:
for _, config in enumerate(vggblock_config):
(
out_channels,
conv_kernel_size,
pooling_kernel_size,
num_conv_layers,
layer_norm,
) = config
self.conv_layers.append(
VGGBlock(
in_channels,
out_channels,
conv_kernel_size,
pooling_kernel_size,
num_conv_layers,
input_dim=input_feat_per_channel,
layer_norm=layer_norm,
)
)
self.pooling_kernel_sizes.append(pooling_kernel_size)
in_channels = out_channels
input_feat_per_channel = self.conv_layers[-1].output_dim
transformer_input_dim = self.infer_conv_output_dim(
self.in_channels, self.input_dim
)
# transformer_input_dim is the output dimension of VGG part
self.validate_transformer_config(transformer_config)
self.transformer_context = self.parse_transformer_context(transformer_context)
self.transformer_sampling = self.parse_transformer_sampling(
transformer_sampling, len(transformer_config)
)
self.transformer_layers = nn.ModuleList()
if transformer_input_dim != transformer_config[0][0]:
self.transformer_layers.append(
Linear(transformer_input_dim, transformer_config[0][0])
)
self.transformer_layers.append(
TransformerEncoderLayer(
prepare_transformer_encoder_params(*transformer_config[0])
)
)
for i in range(1, len(transformer_config)):
if transformer_config[i - 1][0] != transformer_config[i][0]:
self.transformer_layers.append(
Linear(transformer_config[i - 1][0], transformer_config[i][0])
)
self.transformer_layers.append(
TransformerEncoderLayer(
prepare_transformer_encoder_params(*transformer_config[i])
)
)
self.encoder_output_dim = encoder_output_dim
self.transformer_layers.extend(
[
Linear(transformer_config[-1][0], encoder_output_dim),
LayerNorm(encoder_output_dim),
]
)
def forward(self, src_tokens, src_lengths, **kwargs):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
bsz, max_seq_len, _ = src_tokens.size()
x = src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
x = x.transpose(1, 2).contiguous()
# (B, C, T, feat)
for layer_idx in range(len(self.conv_layers)):
x = self.conv_layers[layer_idx](x)
bsz, _, output_seq_len, _ = x.size()
# (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> (T, B, C * feat)
x = x.transpose(1, 2).transpose(0, 1)
x = x.contiguous().view(output_seq_len, bsz, -1)
input_lengths = src_lengths.clone()
for s in self.pooling_kernel_sizes:
input_lengths = (input_lengths.float() / s).ceil().long()
encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
input_lengths, batch_first=True
)
if not encoder_padding_mask.any():
encoder_padding_mask = None
subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor)
transformer_layer_idx = 0
for layer_idx in range(len(self.transformer_layers)):
if isinstance(self.transformer_layers[layer_idx], TransformerEncoderLayer):
x = self.transformer_layers[layer_idx](
x, encoder_padding_mask, attn_mask
)
if self.transformer_sampling[transformer_layer_idx] != 1:
sampling_factor = self.transformer_sampling[transformer_layer_idx]
x, encoder_padding_mask, attn_mask = self.slice(
x, encoder_padding_mask, attn_mask, sampling_factor
)
transformer_layer_idx += 1
else:
x = self.transformer_layers[layer_idx](x)
# encoder_padding_mask is a (T x B) tensor; its [t, b] element indicates
# whether encoder_output[t, b] is valid (valid=0, invalid=1)
return {
"encoder_out": x, # (T, B, C)
"encoder_padding_mask": encoder_padding_mask.t()
if encoder_padding_mask is not None
else None,
# (B, T) --> (T, B)
}
def infer_conv_output_dim(self, in_channels, input_dim):
sample_seq_len = 200
sample_bsz = 10
x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim)
for i, _ in enumerate(self.conv_layers):
x = self.conv_layers[i](x)
x = x.transpose(1, 2)
mb, seq = x.size()[:2]
return x.contiguous().view(mb, seq, -1).size(-1)
def validate_transformer_config(self, transformer_config):
for config in transformer_config:
input_dim, num_heads = config[:2]
if input_dim % num_heads != 0:
msg = (
"ERROR in transformer config {}: ".format(config)
+ "input dimension {} ".format(input_dim)
+ "not dividable by number of heads {}".format(num_heads)
)
raise ValueError(msg)
def parse_transformer_context(self, transformer_context):
"""
transformer_context can be the following:
- None; indicates no context is used, i.e.,
transformer can access full context
- a tuple/list of two int; indicates left and right context,
any number <0 indicates infinite context
* e.g., (5, 6) indicates that for query at x_t, transformer can
access [t-5, t+6] (inclusive)
* e.g., (-1, 6) indicates that for query at x_t, transformer can
access [0, t+6] (inclusive)
"""
if transformer_context is None:
return None
if not isinstance(transformer_context, Iterable):
raise ValueError("transformer context must be Iterable if it is not None")
if len(transformer_context) != 2:
raise ValueError("transformer context must have length 2")
left_context = transformer_context[0]
if left_context < 0:
left_context = None
right_context = transformer_context[1]
if right_context < 0:
right_context = None
if left_context is None and right_context is None:
return None
return (left_context, right_context)
def parse_transformer_sampling(self, transformer_sampling, num_layers):
"""
parsing transformer sampling configuration
Args:
- transformer_sampling, accepted input:
* None, indicating no sampling
* an Iterable with int (>0) as element
- num_layers, expected number of transformer layers; must match the
length of transformer_sampling if it is not None
Returns:
- A tuple with length num_layers
"""
if transformer_sampling is None:
return (1,) * num_layers
if not isinstance(transformer_sampling, Iterable):
raise ValueError(
"transformer_sampling must be an iterable if it is not None"
)
if len(transformer_sampling) != num_layers:
raise ValueError(
"transformer_sampling {} does not match with the number "
"of layers {}".format(transformer_sampling, num_layers)
)
for layer, value in enumerate(transformer_sampling):
if not isinstance(value, int):
raise ValueError("Invalid value in transformer_sampling: ")
if value < 1:
raise ValueError(
"{} layer's subsampling is {}.".format(layer, value)
+ " This is not allowed! "
)
return transformer_sampling
def slice(self, embedding, padding_mask, attn_mask, sampling_factor):
"""
embedding is a (T, B, D) tensor
padding_mask is a (B, T) tensor or None
attn_mask is a (T, T) tensor or None
"""
embedding = embedding[::sampling_factor, :, :]
if padding_mask is not None:
padding_mask = padding_mask[:, ::sampling_factor]
if attn_mask is not None:
attn_mask = attn_mask[::sampling_factor, ::sampling_factor]
return embedding, padding_mask, attn_mask
def lengths_to_attn_mask(self, input_lengths, subsampling_factor=1):
"""
create attention mask according to sequence lengths and transformer
context
Args:
- input_lengths: (B, )-shape Int/Long tensor; input_lengths[b] is
the length of b-th sequence
- subsampling_factor: int
* Note that left_context and right_context are specified at the
input frame level, while the transformer input may already have been
subsampled (e.g., by striding in the vggblocks); subsampling_factor
is used to scale the left/right context accordingly
Return:
- a (T, T) binary tensor or None, where T is max(input_lengths)
* if self.transformer_context is None, None
* if left_context is None,
* attn_mask[t, t + right_context + 1:] = 1
* others = 0
* if right_context is None,
* attn_mask[t, 0:t - left_context] = 1
* others = 0
* otherwise
* attn_mask[t, t - left_context: t + right_context + 1] = 0
* others = 1
"""
if self.transformer_context is None:
return None
maxT = torch.max(input_lengths).item()
attn_mask = torch.zeros(maxT, maxT)
left_context = self.transformer_context[0]
right_context = self.transformer_context[1]
if left_context is not None:
left_context = math.ceil(self.transformer_context[0] / subsampling_factor)
if right_context is not None:
right_context = math.ceil(self.transformer_context[1] / subsampling_factor)
for t in range(maxT):
if left_context is not None:
st = 0
en = max(st, t - left_context)
attn_mask[t, st:en] = 1
if right_context is not None:
st = t + right_context + 1
st = min(st, maxT - 1)
attn_mask[t, st:] = 1
return attn_mask.to(input_lengths.device)
def reorder_encoder_out(self, encoder_out, new_order):
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
1, new_order
)
if encoder_out["encoder_padding_mask"] is not None:
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(1, new_order)
return encoder_out
class TransformerDecoder(FairseqIncrementalDecoder):
"""
Transformer decoder consisting of one :class:`TransformerDecoderLayer` per
entry in *transformer_config*, preceded by a stack of 1-D convolutions.
Args:
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_dim (int): embedding dimension of the target tokens
transformer_config: tuple of per-layer transformer configurations; see
DEFAULT_ENC_TRANSFORMER_CONFIG for the field layout
conv_config: tuple of per-layer 1-D convolution configurations; see
DEFAULT_DEC_CONV_CONFIG for the field layout
encoder_output_dim (int): dimension of the encoder output this decoder
attends over
"""
def __init__(
self,
dictionary,
embed_dim=512,
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
conv_config=DEFAULT_DEC_CONV_CONFIG,
encoder_output_dim=512,
):
super().__init__(dictionary)
vocab_size = len(dictionary)
self.padding_idx = dictionary.pad()
self.embed_tokens = Embedding(vocab_size, embed_dim, self.padding_idx)
self.conv_layers = nn.ModuleList()
for i in range(len(conv_config)):
out_channels, kernel_size, layer_norm = conv_config[i]
if i == 0:
conv_layer = LinearizedConv1d(
embed_dim, out_channels, kernel_size, padding=kernel_size - 1
)
else:
conv_layer = LinearizedConv1d(
conv_config[i - 1][0],
out_channels,
kernel_size,
padding=kernel_size - 1,
)
self.conv_layers.append(conv_layer)
if layer_norm:
self.conv_layers.append(nn.LayerNorm(out_channels))
self.conv_layers.append(nn.ReLU())
self.layers = nn.ModuleList()
if conv_config[-1][0] != transformer_config[0][0]:
self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
self.layers.append(
TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[0])
)
)
for i in range(1, len(transformer_config)):
if transformer_config[i - 1][0] != transformer_config[i][0]:
self.layers.append(
Linear(transformer_config[i - 1][0], transformer_config[i][0])
)
self.layers.append(
TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[i])
)
)
self.fc_out = Linear(transformer_config[-1][0], vocab_size)
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for input feeding/teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
Returns:
tuple:
- the last decoder layer's output of shape `(batch, tgt_len,
vocab)`
- the last decoder layer's attention weights of shape `(batch,
tgt_len, src_len)`
"""
target_padding_mask = (
(prev_output_tokens == self.padding_idx).to(prev_output_tokens.device)
if incremental_state is None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
# embed tokens
x = self.embed_tokens(prev_output_tokens)
# B x T x C -> T x B x C
x = self._transpose_if_training(x, incremental_state)
for layer in self.conv_layers:
if isinstance(layer, LinearizedConvolution):
x = layer(x, incremental_state)
else:
x = layer(x)
# B x T x C -> T x B x C
x = self._transpose_if_inference(x, incremental_state)
# decoder layers
for layer in self.layers:
if isinstance(layer, TransformerDecoderLayer):
x, *_ = layer(
x,
(encoder_out["encoder_out"] if encoder_out is not None else None),
(
encoder_out["encoder_padding_mask"].t()
if encoder_out["encoder_padding_mask"] is not None
else None
),
incremental_state,
self_attn_mask=(
self.buffered_future_mask(x)
if incremental_state is None
else None
),
self_attn_padding_mask=(
target_padding_mask if incremental_state is None else None
),
)
else:
x = layer(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
x = self.fc_out(x)
return x, None
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
not hasattr(self, "_future_mask")
or self._future_mask is None
or self._future_mask.device != tensor.device
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
)
if self._future_mask.size(0) < dim:
self._future_mask = torch.triu(
utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
)
return self._future_mask[:dim, :dim]
def _transpose_if_training(self, x, incremental_state):
if incremental_state is None:
x = x.transpose(0, 1)
return x
def _transpose_if_inference(self, x, incremental_state):
if incremental_state:
x = x.transpose(0, 1)
return x
@register_model("asr_vggtransformer_encoder")
class VGGTransformerEncoderModel(FairseqEncoderModel):
def __init__(self, encoder):
super().__init__(encoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--input-feat-per-channel",
type=int,
metavar="N",
help="encoder input dimension per input channel",
)
parser.add_argument(
"--vggblock-enc-config",
type=str,
metavar="EXPR",
help="""
an array of tuples each containing the configuration of one vggblock
[(out_channels, conv_kernel_size, pooling_kernel_size,num_conv_layers), ...]
""",
)
parser.add_argument(
"--transformer-enc-config",
type=str,
metavar="EXPR",
help="""
a tuple containing the configuration of the Transformer layers
configurations:
[(input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout), ]""",
)
parser.add_argument(
"--enc-output-dim",
type=int,
metavar="N",
help="encoder output dimension, projecting the LSTM output",
)
parser.add_argument(
"--in-channels",
type=int,
metavar="N",
help="number of encoder input channels",
)
parser.add_argument(
"--transformer-context",
type=str,
metavar="EXPR",
help="""
either None or a tuple of two ints, indicating left/right context a
transformer can have access to""",
)
parser.add_argument(
"--transformer-sampling",
type=str,
metavar="EXPR",
help="""
either None or a tuple of ints, indicating sampling factor in each layer""",
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
base_architecture_enconly(args)
encoder = VGGTransformerEncoderOnly(
vocab_size=len(task.target_dictionary),
input_feat_per_channel=args.input_feat_per_channel,
vggblock_config=eval(args.vggblock_enc_config),
transformer_config=eval(args.transformer_enc_config),
encoder_output_dim=args.enc_output_dim,
in_channels=args.in_channels,
transformer_context=eval(args.transformer_context),
transformer_sampling=eval(args.transformer_sampling),
)
return cls(encoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
# net_output['encoder_out'] is a (T, B, D) tensor
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
# lprobs is a (T, B, D) tensor
# we need to transpose it to get a (B, T, D) tensor
lprobs = lprobs.transpose(0, 1).contiguous()
lprobs.batch_first = True
return lprobs
class VGGTransformerEncoderOnly(VGGTransformerEncoder):
def __init__(
self,
vocab_size,
input_feat_per_channel,
vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
encoder_output_dim=512,
in_channels=1,
transformer_context=None,
transformer_sampling=None,
):
super().__init__(
input_feat_per_channel=input_feat_per_channel,
vggblock_config=vggblock_config,
transformer_config=transformer_config,
encoder_output_dim=encoder_output_dim,
in_channels=in_channels,
transformer_context=transformer_context,
transformer_sampling=transformer_sampling,
)
self.fc_out = Linear(self.encoder_output_dim, vocab_size)
def forward(self, src_tokens, src_lengths, **kwargs):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
enc_out = super().forward(src_tokens, src_lengths)
x = self.fc_out(enc_out["encoder_out"])
# x = F.log_softmax(x, dim=-1)
# Note: this line is not needed, because model.get_normalized_probs will
# apply log_softmax
return {
"encoder_out": x, # (T, B, C)
"encoder_padding_mask": enc_out["encoder_padding_mask"], # (T, B)
}
def max_positions(self):
"""Maximum input length supported by the encoder."""
return (1e6, 1e6)  # an arbitrarily large number
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
# nn.init.uniform_(m.weight, -0.1, 0.1)
# nn.init.constant_(m.weight[padding_idx], 0)
return m
def Linear(in_features, out_features, bias=True, dropout=0):
"""Linear layer (input: N x T x C)"""
m = nn.Linear(in_features, out_features, bias=bias)
# m.weight.data.uniform_(-0.1, 0.1)
# if bias:
# m.bias.data.uniform_(-0.1, 0.1)
return m
def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
"""Weight-normalized Conv1d layer optimized for decoding"""
m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
nn.init.normal_(m.weight, mean=0, std=std)
nn.init.constant_(m.bias, 0)
return nn.utils.weight_norm(m, dim=2)
def LayerNorm(embedding_dim):
m = nn.LayerNorm(embedding_dim)
return m
# seq2seq models
def base_architecture(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", DEFAULT_ENC_VGGBLOCK_CONFIG
)
args.transformer_enc_config = getattr(
args, "transformer_enc_config", DEFAULT_ENC_TRANSFORMER_CONFIG
)
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
args.in_channels = getattr(args, "in_channels", 1)
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
args.transformer_dec_config = getattr(
args, "transformer_dec_config", DEFAULT_ENC_TRANSFORMER_CONFIG
)
args.conv_dec_config = getattr(args, "conv_dec_config", DEFAULT_DEC_CONV_CONFIG)
args.transformer_context = getattr(args, "transformer_context", "None")
@register_model_architecture("asr_vggtransformer", "vggtransformer_1")
def vggtransformer_1(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args.transformer_enc_config = getattr(
args,
"transformer_enc_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14",
)
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
args.transformer_dec_config = getattr(
args,
"transformer_dec_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4",
)
@register_model_architecture("asr_vggtransformer", "vggtransformer_2")
def vggtransformer_2(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args.transformer_enc_config = getattr(
args,
"transformer_enc_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
)
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
args.transformer_dec_config = getattr(
args,
"transformer_dec_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6",
)
@register_model_architecture("asr_vggtransformer", "vggtransformer_base")
def vggtransformer_base(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args.transformer_enc_config = getattr(
args, "transformer_enc_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12"
)
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
args.transformer_dec_config = getattr(
args, "transformer_dec_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6"
)
# Size estimations:
# Encoder:
# - vggblock param: 64*1*3*3 + 64*64*3*3 + 128*64*3*3 + 128*128*3*3 = 258K
# Transformer:
# - input dimension adapter: 2560 x 512 -> 1.31M
# - transformer_layers (x12) --> 37.74M
# * MultiheadAttention: 512*512*3 (in_proj) + 512*512 (out_proj) = 1.048M
# * FFN weight: 512*2048*2 = 2.097M
# - output dimension adapter: 512 x 512 -> 0.26 M
# Decoder:
# - LinearizedConv1d: 512 * 256 * 3 + 256 * 256 * 3 * 3
# - transformer_layer: (x6) --> 25.16M
# * MultiheadAttention (self-attention): 512*512*3 + 512*512 = 1.048M
# * MultiheadAttention (encoder-attention): 512*512*3 + 512*512 = 1.048M
# * FFN: 512*2048*2 = 2.097M
# Final FC:
# - FC: 512*5000 = 256K (assuming vocab size 5K)
# In total:
# ~65 M
# CTC models
def base_architecture_enconly(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(32, 3, 2, 2, True)] * 2"
)
args.transformer_enc_config = getattr(
args, "transformer_enc_config", "((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2"
)
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
args.in_channels = getattr(args, "in_channels", 1)
args.transformer_context = getattr(args, "transformer_context", "None")
args.transformer_sampling = getattr(args, "transformer_sampling", "None")
@register_model_architecture("asr_vggtransformer_encoder", "vggtransformer_enc_1")
def vggtransformer_enc_1(args):
# vggtransformer_1 is the same as vggtransformer_enc_big, except that the
# number of layers is increased to 16
# kept here for backward compatibility
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args.transformer_enc_config = getattr(
args,
"transformer_enc_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
)
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
from fairseq.modules.fairseq_dropout import FairseqDropout
default_conv_enc_config = """[
(400, 13, 170, 0.2),
(440, 14, 0, 0.214),
(484, 15, 0, 0.22898),
(532, 16, 0, 0.2450086),
(584, 17, 0, 0.262159202),
(642, 18, 0, 0.28051034614),
(706, 19, 0, 0.30014607037),
(776, 20, 0, 0.321156295296),
(852, 21, 0, 0.343637235966),
(936, 22, 0, 0.367691842484),
(1028, 23, 0, 0.393430271458),
(1130, 24, 0, 0.42097039046),
(1242, 25, 0, 0.450438317792),
(1366, 26, 0, 0.481969000038),
(1502, 27, 0, 0.51570683004),
(1652, 28, 0, 0.551806308143),
(1816, 29, 0, 0.590432749713),
]"""
@register_model("asr_w2l_conv_glu_encoder")
class W2lConvGluEncoderModel(FairseqEncoderModel):
def __init__(self, encoder):
super().__init__(encoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--input-feat-per-channel",
type=int,
metavar="N",
help="encoder input dimension per input channel",
)
parser.add_argument(
"--in-channels",
type=int,
metavar="N",
help="number of encoder input channels",
)
parser.add_argument(
"--conv-enc-config",
type=str,
metavar="EXPR",
help="""
an array of tuples each containing the configuration of one conv layer
[(out_channels, kernel_size, padding, dropout), ...]
""",
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
encoder = W2lConvGluEncoder(
vocab_size=len(task.target_dictionary),
input_feat_per_channel=args.input_feat_per_channel,
in_channels=args.in_channels,
conv_enc_config=eval(conv_enc_config),
)
return cls(encoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
lprobs.batch_first = False
return lprobs
class W2lConvGluEncoder(FairseqEncoder):
def __init__(
self, vocab_size, input_feat_per_channel, in_channels, conv_enc_config
):
super().__init__(None)
self.input_dim = input_feat_per_channel
if in_channels != 1:
raise ValueError("only 1 input channel is currently supported")
self.conv_layers = nn.ModuleList()
self.linear_layers = nn.ModuleList()
self.dropouts = []
cur_channels = input_feat_per_channel
for out_channels, kernel_size, padding, dropout in conv_enc_config:
layer = nn.Conv1d(cur_channels, out_channels, kernel_size, padding=padding)
layer.weight.data.mul_(math.sqrt(3)) # match wav2letter init
self.conv_layers.append(nn.utils.weight_norm(layer))
self.dropouts.append(
FairseqDropout(dropout, module_name=self.__class__.__name__)
)
if out_channels % 2 != 0:
raise ValueError("odd # of out_channels is incompatible with GLU")
cur_channels = out_channels // 2 # halved by GLU
for out_channels in [2 * cur_channels, vocab_size]:
layer = nn.Linear(cur_channels, out_channels)
layer.weight.data.mul_(math.sqrt(3))
self.linear_layers.append(nn.utils.weight_norm(layer))
cur_channels = out_channels // 2
def forward(self, src_tokens, src_lengths, **kwargs):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
B, T, _ = src_tokens.size()
x = src_tokens.transpose(1, 2).contiguous() # (B, feat, T) assuming C == 1
for layer_idx in range(len(self.conv_layers)):
x = self.conv_layers[layer_idx](x)
x = F.glu(x, dim=1)
x = self.dropouts[layer_idx](x)
x = x.transpose(1, 2).contiguous() # (B, T, 908)
x = self.linear_layers[0](x)
x = F.glu(x, dim=2)
x = self.dropouts[-1](x)
x = self.linear_layers[1](x)
assert x.size(0) == B
assert x.size(1) == T
encoder_out = x.transpose(0, 1) # (T, B, vocab_size)
# TODO: find a simpler/more elegant way to build this mask with the PyTorch APIs
encoder_padding_mask = (
torch.arange(T).view(1, T).expand(B, -1).to(x.device)
>= src_lengths.view(B, 1).expand(-1, T)
).t() # (B x T) -> (T x B)
return {
"encoder_out": encoder_out, # (T, B, vocab_size)
"encoder_padding_mask": encoder_padding_mask, # (T, B)
}
def reorder_encoder_out(self, encoder_out, new_order):
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
1, new_order
)
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(1, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
return (1e6, 1e6)  # an arbitrarily large number
@register_model_architecture("asr_w2l_conv_glu_encoder", "w2l_conv_glu_enc")
def w2l_conv_glu_enc(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.in_channels = getattr(args, "in_channels", 1)
args.conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
# Flashlight Decoder
This script runs decoding for pre-trained speech recognition models.
## Usage
Assuming a few variables:
```bash
checkpoint=<path-to-checkpoint>
data=<path-to-data-directory>
lm_model=<path-to-language-model>
lexicon=<path-to-lexicon>
```
Example usage for decoding a fine-tuned Wav2Vec model:
```bash
python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \
task=audio_pretraining \
task.data=$data \
task.labels=ltr \
common_eval.path=$checkpoint \
decoding.type=kenlm \
decoding.lexicon=$lexicon \
decoding.lmpath=$lm_model \
dataset.gen_subset=dev_clean,dev_other,test_clean,test_other
```
Example of using Ax to sweep the WER-related decoding parameters (requires `pip install hydra-ax-sweeper`):
```bash
python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \
hydra/sweeper=ax \
task=audio_pretraining \
task.data=$data \
task.labels=ltr \
common_eval.path=$checkpoint \
decoding.type=kenlm \
decoding.lexicon=$lexicon \
decoding.lmpath=$lm_model \
dataset.gen_subset=dev_other
```
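The language-model weight and word insertion score can also be set directly without a sweep; the override names below match the `FlashlightDecoderConfig` fields, and the values shown are illustrative rather than tuned:
```bash
python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py \
    task=audio_pretraining \
    task.data=$data \
    task.labels=ltr \
    common_eval.path=$checkpoint \
    decoding.type=kenlm \
    decoding.lexicon=$lexicon \
    decoding.lmpath=$lm_model \
    decoding.lmweight=2.0 \
    decoding.wordscore=-1.0 \
    dataset.gen_subset=dev_other
```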
# @package hydra.sweeper
_target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper
max_batch_size: null
ax_config:
max_trials: 128
early_stop:
minimize: true
max_epochs_without_improvement: 32
epsilon: 1.0e-05
experiment:
name: ${dataset.gen_subset}
objective_name: wer
minimize: true
parameter_constraints: null
outcome_constraints: null
status_quo: null
client:
verbose_logging: false
random_seed: null
params:
decoding.lmweight:
type: range
bounds: [0.0, 5.0]
decoding.wordscore:
type: range
bounds: [-5.0, 5.0]
# @package _group_
defaults:
- task: null
- model: null
hydra:
run:
dir: ${common_eval.results_path}/${dataset.gen_subset}
sweep:
dir: ${common_eval.results_path}
subdir: ${dataset.gen_subset}
common_eval:
results_path: null
path: null
post_process: letter
quiet: true
dataset:
max_tokens: 1000000
gen_subset: test
distributed_training:
distributed_world_size: 1
decoding:
beam: 5
type: viterbi
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools as it
from typing import Any, Dict, List
import torch
from fairseq.data.dictionary import Dictionary
from fairseq.models.fairseq_model import FairseqModel
class BaseDecoder:
def __init__(self, tgt_dict: Dictionary) -> None:
self.tgt_dict = tgt_dict
self.vocab_size = len(tgt_dict)
self.blank = (
tgt_dict.index("<ctc_blank>")
if "<ctc_blank>" in tgt_dict.indices
else tgt_dict.bos()
)
if "<sep>" in tgt_dict.indices:
self.silence = tgt_dict.index("<sep>")
elif "|" in tgt_dict.indices:
self.silence = tgt_dict.index("|")
else:
self.silence = tgt_dict.eos()
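# Note on the choices above: "<ctc_blank>" is used as the blank index when the
# dictionary defines it, otherwise bos() doubles as the blank; the silence
# index falls back from "<sep>" to the letter-level word boundary "|" and
# finally to eos().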
def generate(
self, models: List[FairseqModel], sample: Dict[str, Any], **unused
) -> List[List[Dict[str, torch.LongTensor]]]:
encoder_input = {
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
emissions = self.get_emissions(models, encoder_input)
return self.decode(emissions)
def get_emissions(
self,
models: List[FairseqModel],
encoder_input: Dict[str, Any],
) -> torch.FloatTensor:
model = models[0]
encoder_out = model(**encoder_input)
if hasattr(model, "get_logits"):
emissions = model.get_logits(encoder_out)
else:
emissions = model.get_normalized_probs(encoder_out, log_probs=True)
return emissions.transpose(0, 1).float().cpu().contiguous()
def get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
idxs = (g[0] for g in it.groupby(idxs))
idxs = filter(lambda x: x != self.blank, idxs)
return torch.LongTensor(list(idxs))
def decode(
self,
emissions: torch.FloatTensor,
) -> List[List[Dict[str, torch.LongTensor]]]:
raise NotImplementedError
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Union
from fairseq.data.dictionary import Dictionary
from .decoder_config import DecoderConfig, FlashlightDecoderConfig
from .base_decoder import BaseDecoder
def Decoder(
cfg: Union[DecoderConfig, FlashlightDecoderConfig], tgt_dict: Dictionary
) -> BaseDecoder:
if cfg.type == "viterbi":
from .viterbi_decoder import ViterbiDecoder
return ViterbiDecoder(tgt_dict)
if cfg.type == "kenlm":
from .flashlight_decoder import KenLMDecoder
return KenLMDecoder(cfg, tgt_dict)
if cfg.type == "fairseqlm":
from .flashlight_decoder import FairseqLMDecoder
return FairseqLMDecoder(cfg, tgt_dict)
raise NotImplementedError(f"Invalid decoder name: {cfg.name}")
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
from typing import Optional
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.dataclass.constants import ChoiceEnum
from omegaconf import MISSING
DECODER_CHOICES = ChoiceEnum(["viterbi", "kenlm", "fairseqlm"])
@dataclass
class DecoderConfig(FairseqDataclass):
type: DECODER_CHOICES = field(
default="viterbi",
metadata={"help": "The type of decoder to use"},
)
@dataclass
class FlashlightDecoderConfig(FairseqDataclass):
nbest: int = field(
default=1,
metadata={"help": "Number of decodings to return"},
)
unitlm: bool = field(
default=False,
metadata={"help": "If set, use unit language model"},
)
lmpath: str = field(
default=MISSING,
metadata={"help": "Language model for KenLM decoder"},
)
lexicon: Optional[str] = field(
default=None,
metadata={"help": "Lexicon for Flashlight decoder"},
)
beam: int = field(
default=50,
metadata={"help": "Number of beams to use for decoding"},
)
beamthreshold: float = field(
default=50.0,
metadata={"help": "Threshold for beam search decoding"},
)
beamsizetoken: Optional[int] = field(
default=None, metadata={"help": "Beam size to use"}
)
wordscore: float = field(
default=-1,
metadata={"help": "Word score for KenLM decoder"},
)
unkweight: float = field(
default=-math.inf,
metadata={"help": "Unknown weight for KenLM decoder"},
)
silweight: float = field(
default=0,
metadata={"help": "Silence weight for KenLM decoder"},
)
lmweight: float = field(
default=2,
metadata={"help": "Weight for LM while interpolating score"},
)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import gc
import os.path as osp
import warnings
from collections import deque, namedtuple
from typing import Any, Dict, Tuple
import numpy as np
import torch
from fairseq import tasks
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models.fairseq_model import FairseqModel
from fairseq.utils import apply_to_sample
from omegaconf import open_dict, OmegaConf
from typing import List
from .decoder_config import FlashlightDecoderConfig
from .base_decoder import BaseDecoder
try:
from flashlight.lib.text.decoder import (
LM,
CriterionType,
DecodeResult,
KenLM,
LexiconDecoder,
LexiconDecoderOptions,
LexiconFreeDecoder,
LexiconFreeDecoderOptions,
LMState,
SmearingMode,
Trie,
)
from flashlight.lib.text.dictionary import create_word_dict, load_words
except ImportError:
warnings.warn(
"flashlight python bindings are required to use this functionality. "
"Please install from "
"https://github.com/facebookresearch/flashlight/tree/master/bindings/python"
)
LM = object
LMState = object
class KenLMDecoder(BaseDecoder):
def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None:
super().__init__(tgt_dict)
self.nbest = cfg.nbest
self.unitlm = cfg.unitlm
if cfg.lexicon:
self.lexicon = load_words(cfg.lexicon)
self.word_dict = create_word_dict(self.lexicon)
self.unk_word = self.word_dict.get_index("<unk>")
self.lm = KenLM(cfg.lmpath, self.word_dict)
self.trie = Trie(self.vocab_size, self.silence)
start_state = self.lm.start(False)
for word, spellings in self.lexicon.items():
word_idx = self.word_dict.get_index(word)
_, score = self.lm.score(start_state, word_idx)
for spelling in spellings:
spelling_idxs = [tgt_dict.index(token) for token in spelling]
assert (
tgt_dict.unk() not in spelling_idxs
), f"{word} {spelling} {spelling_idxs}"
self.trie.insert(spelling_idxs, word_idx, score)
self.trie.smear(SmearingMode.MAX)
self.decoder_opts = LexiconDecoderOptions(
beam_size=cfg.beam,
beam_size_token=cfg.beamsizetoken or len(tgt_dict),
beam_threshold=cfg.beamthreshold,
lm_weight=cfg.lmweight,
word_score=cfg.wordscore,
unk_score=cfg.unkweight,
sil_score=cfg.silweight,
log_add=False,
criterion_type=CriterionType.CTC,
)
self.decoder = LexiconDecoder(
self.decoder_opts,
self.trie,
self.lm,
self.silence,
self.blank,
self.unk_word,
[],
self.unitlm,
)
else:
assert self.unitlm, "Lexicon-free decoding requires unit LM"
d = {w: [[w]] for w in tgt_dict.symbols}
self.word_dict = create_word_dict(d)
self.lm = KenLM(cfg.lmpath, self.word_dict)
self.decoder_opts = LexiconFreeDecoderOptions(
beam_size=cfg.beam,
beam_size_token=cfg.beamsizetoken or len(tgt_dict),
beam_threshold=cfg.beamthreshold,
lm_weight=cfg.lmweight,
sil_score=cfg.silweight,
log_add=False,
criterion_type=CriterionType.CTC,
)
self.decoder = LexiconFreeDecoder(
self.decoder_opts, self.lm, self.silence, self.blank, []
)
def get_timesteps(self, token_idxs: List[int]) -> List[int]:
"""Returns frame numbers corresponding to every non-blank token.
Parameters
----------
token_idxs : List[int]
IDs of decoded tokens.
Returns
-------
List[int]
Frame numbers corresponding to every non-blank token.
"""
timesteps = []
for i, token_idx in enumerate(token_idxs):
if token_idx == self.blank:
continue
if i == 0 or token_idx != token_idxs[i-1]:
timesteps.append(i)
return timesteps
def decode(
self,
emissions: torch.FloatTensor,
) -> List[List[Dict[str, torch.LongTensor]]]:
B, T, N = emissions.size()
hypos = []
for b in range(B):
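# emissions is float32 (see get_emissions), so each element occupies 4 bytes;
# advance the raw pointer to the start of utterance b.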
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
results = self.decoder.decode(emissions_ptr, T, N)
nbest_results = results[: self.nbest]
hypos.append(
[
{
"tokens": self.get_tokens(result.tokens),
"score": result.score,
"timesteps": self.get_timesteps(result.tokens),
"words": [
self.word_dict.get_entry(x) for x in result.words if x >= 0
],
}
for result in nbest_results
]
)
return hypos
FairseqLMState = namedtuple(
"FairseqLMState",
[
"prefix",
"incremental_state",
"probs",
],
)
class FairseqLM(LM):
def __init__(self, dictionary: Dictionary, model: FairseqModel) -> None:
super().__init__()
self.dictionary = dictionary
self.model = model
self.unk = self.dictionary.unk()
self.save_incremental = False # this currently does not work properly
self.max_cache = 20_000
if torch.cuda.is_available():
model.cuda()
model.eval()
model.make_generation_fast_()
self.states = {}
self.stateq = deque()
def start(self, start_with_nothing: bool) -> LMState:
state = LMState()
prefix = torch.LongTensor([[self.dictionary.eos()]])
incremental_state = {} if self.save_incremental else None
with torch.no_grad():
res = self.model(prefix.cuda(), incremental_state=incremental_state)
probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)
if incremental_state is not None:
incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)
self.states[state] = FairseqLMState(
prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
)
self.stateq.append(state)
return state
def score(
self,
state: LMState,
token_index: int,
no_cache: bool = False,
) -> Tuple[LMState, float]:
"""
Evaluate language model based on the current lm state and new word
Parameters:
-----------
state: current lm state
token_index: index of the word
(this can be a lexicon index, in which case the LM must store the
mapping between lexicon and LM indices, or directly the LM index of
a word)
Returns:
--------
(LMState, float): pair of (new state, score for the current word)
"""
curr_state = self.states[state]
def trim_cache(targ_size: int) -> None:
while len(self.stateq) > targ_size:
rem_k = self.stateq.popleft()
rem_st = self.states[rem_k]
rem_st = FairseqLMState(rem_st.prefix, None, None)
self.states[rem_k] = rem_st
if curr_state.probs is None:
new_incremental_state = (
curr_state.incremental_state.copy()
if curr_state.incremental_state is not None
else None
)
with torch.no_grad():
if new_incremental_state is not None:
new_incremental_state = apply_to_sample(
lambda x: x.cuda(), new_incremental_state
)
elif self.save_incremental:
new_incremental_state = {}
res = self.model(
torch.from_numpy(curr_state.prefix).cuda(),
incremental_state=new_incremental_state,
)
probs = self.model.get_normalized_probs(
res, log_probs=True, sample=None
)
if new_incremental_state is not None:
new_incremental_state = apply_to_sample(
lambda x: x.cpu(), new_incremental_state
)
curr_state = FairseqLMState(
curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
)
if not no_cache:
self.states[state] = curr_state
self.stateq.append(state)
score = curr_state.probs[token_index].item()
trim_cache(self.max_cache)
outstate = state.child(token_index)
if outstate not in self.states and not no_cache:
prefix = np.concatenate(
[curr_state.prefix, torch.LongTensor([[token_index]])], -1
)
incr_state = curr_state.incremental_state
self.states[outstate] = FairseqLMState(prefix, incr_state, None)
if token_index == self.unk:
score = float("-inf")
return outstate, score
def finish(self, state: LMState) -> Tuple[LMState, float]:
"""
Evaluate eos for language model based on the current lm state
Returns:
--------
(LMState, float): pair of (new state, score for the current word)
"""
return self.score(state, self.dictionary.eos())
def empty_cache(self) -> None:
self.states = {}
self.stateq = deque()
gc.collect()
class FairseqLMDecoder(BaseDecoder):
def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None:
super().__init__(tgt_dict)
self.nbest = cfg.nbest
self.unitlm = cfg.unitlm
self.lexicon = load_words(cfg.lexicon) if cfg.lexicon else None
self.idx_to_wrd = {}
checkpoint = torch.load(cfg.lmpath, map_location="cpu")
if "cfg" in checkpoint and checkpoint["cfg"] is not None:
lm_args = checkpoint["cfg"]
else:
lm_args = convert_namespace_to_omegaconf(checkpoint["args"])
if not OmegaConf.is_dict(lm_args):
lm_args = OmegaConf.create(lm_args)
with open_dict(lm_args.task):
lm_args.task.data = osp.dirname(cfg.lmpath)
task = tasks.setup_task(lm_args.task)
model = task.build_model(lm_args.model)
model.load_state_dict(checkpoint["model"], strict=False)
self.trie = Trie(self.vocab_size, self.silence)
self.word_dict = task.dictionary
self.unk_word = self.word_dict.unk()
self.lm = FairseqLM(self.word_dict, model)
if self.lexicon:
start_state = self.lm.start(False)
for i, (word, spellings) in enumerate(self.lexicon.items()):
if self.unitlm:
word_idx = i
self.idx_to_wrd[i] = word
score = 0
else:
word_idx = self.word_dict.index(word)
_, score = self.lm.score(start_state, word_idx, no_cache=True)
for spelling in spellings:
spelling_idxs = [tgt_dict.index(token) for token in spelling]
assert (
tgt_dict.unk() not in spelling_idxs
), f"{spelling} {spelling_idxs}"
self.trie.insert(spelling_idxs, word_idx, score)
self.trie.smear(SmearingMode.MAX)
self.decoder_opts = LexiconDecoderOptions(
beam_size=cfg.beam,
beam_size_token=cfg.beamsizetoken or len(tgt_dict),
beam_threshold=cfg.beamthreshold,
lm_weight=cfg.lmweight,
word_score=cfg.wordscore,
unk_score=cfg.unkweight,
sil_score=cfg.silweight,
log_add=False,
criterion_type=CriterionType.CTC,
)
self.decoder = LexiconDecoder(
self.decoder_opts,
self.trie,
self.lm,
self.silence,
self.blank,
self.unk_word,
[],
self.unitlm,
)
else:
assert self.unitlm, "Lexicon-free decoding requires unit LM"
d = {w: [[w]] for w in tgt_dict.symbols}
self.word_dict = create_word_dict(d)
self.lm = KenLM(cfg.lmpath, self.word_dict)
self.decoder_opts = LexiconFreeDecoderOptions(
beam_size=cfg.beam,
beam_size_token=cfg.beamsizetoken or len(tgt_dict),
beam_threshold=cfg.beamthreshold,
lm_weight=cfg.lmweight,
sil_score=cfg.silweight,
log_add=False,
criterion_type=CriterionType.CTC,
)
self.decoder = LexiconFreeDecoder(
self.decoder_opts, self.lm, self.silence, self.blank, []
)
def decode(
self,
emissions: torch.FloatTensor,
) -> List[List[Dict[str, torch.LongTensor]]]:
B, T, N = emissions.size()
hypos = []
def make_hypo(result: DecodeResult) -> Dict[str, Any]:
hypo = {
"tokens": self.get_tokens(result.tokens),
"score": result.score,
}
if self.lexicon:
hypo["words"] = [
self.idx_to_wrd[x] if self.unitlm else self.word_dict[x]
for x in result.words
if x >= 0
]
return hypo
for b in range(B):
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
results = self.decoder.decode(emissions_ptr, T, N)
nbest_results = results[: self.nbest]
hypos.append([make_hypo(result) for result in nbest_results])
self.lm.empty_cache()
return hypos
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from typing import List, Dict
from .base_decoder import BaseDecoder
class ViterbiDecoder(BaseDecoder):
def decode(
self,
emissions: torch.FloatTensor,
) -> List[List[Dict[str, torch.LongTensor]]]:
def get_pred(e):
toks = e.argmax(dim=-1).unique_consecutive()
return toks[toks != self.blank]
return [[{"tokens": get_pred(x), "score": 0}] for x in emissions]
#!/usr/bin/env python -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ast
import hashlib
import logging
import os
import shutil
import sys
from dataclasses import dataclass, field, is_dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import editdistance
import torch
import torch.distributed as dist
from examples.speech_recognition.new.decoders.decoder_config import (
DecoderConfig,
FlashlightDecoderConfig,
)
from examples.speech_recognition.new.decoders.decoder import Decoder
from fairseq import checkpoint_utils, distributed_utils, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.dataclass.configs import (
CheckpointConfig,
CommonConfig,
CommonEvalConfig,
DatasetConfig,
DistributedTrainingConfig,
FairseqDataclass,
)
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq.logging.progress_bar import BaseProgressBar
from fairseq.models.fairseq_model import FairseqModel
from omegaconf import OmegaConf
import hydra
from hydra.core.config_store import ConfigStore
logging.root.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
config_path = Path(__file__).resolve().parent / "conf"
@dataclass
class DecodingConfig(DecoderConfig, FlashlightDecoderConfig):
unique_wer_file: bool = field(
default=False,
metadata={"help": "If set, use a unique file for storing WER"},
)
results_path: Optional[str] = field(
default=None,
metadata={
"help": "If set, write hypothesis and reference sentences into this directory"
},
)
@dataclass
class InferConfig(FairseqDataclass):
task: Any = None
decoding: DecodingConfig = DecodingConfig()
common: CommonConfig = CommonConfig()
common_eval: CommonEvalConfig = CommonEvalConfig()
checkpoint: CheckpointConfig = CheckpointConfig()
distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
dataset: DatasetConfig = DatasetConfig()
is_ax: bool = field(
default=False,
metadata={
"help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume"
},
)
def reset_logging():
root = logging.getLogger()
for handler in root.handlers:
root.removeHandler(handler)
root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
root.addHandler(handler)
class InferenceProcessor:
cfg: InferConfig
def __init__(self, cfg: InferConfig) -> None:
self.cfg = cfg
self.task = tasks.setup_task(cfg.task)
models, saved_cfg = self.load_model_ensemble()
self.models = models
self.saved_cfg = saved_cfg
self.tgt_dict = self.task.target_dictionary
self.task.load_dataset(
self.cfg.dataset.gen_subset,
task_cfg=saved_cfg.task,
)
self.generator = Decoder(cfg.decoding, self.tgt_dict)
self.gen_timer = StopwatchMeter()
self.wps_meter = TimeMeter()
self.num_sentences = 0
self.total_errors = 0
self.total_length = 0
self.hypo_words_file = None
self.hypo_units_file = None
self.ref_words_file = None
self.ref_units_file = None
self.progress_bar = self.build_progress_bar()
def __enter__(self) -> "InferenceProcessor":
if self.cfg.decoding.results_path is not None:
self.hypo_words_file = self.get_res_file("hypo.word")
self.hypo_units_file = self.get_res_file("hypo.units")
self.ref_words_file = self.get_res_file("ref.word")
self.ref_units_file = self.get_res_file("ref.units")
return self
def __exit__(self, *exc) -> bool:
if self.cfg.decoding.results_path is not None:
self.hypo_words_file.close()
self.hypo_units_file.close()
self.ref_words_file.close()
self.ref_units_file.close()
return False
def __iter__(self) -> Any:
for sample in self.progress_bar:
if not self.cfg.common.cpu:
sample = utils.move_to_cuda(sample)
# Happens on the last batch.
if "net_input" not in sample:
continue
yield sample
def log(self, *args, **kwargs):
self.progress_bar.log(*args, **kwargs)
def print(self, *args, **kwargs):
self.progress_bar.print(*args, **kwargs)
def get_res_file(self, fname: str):
fname = os.path.join(self.cfg.decoding.results_path, fname)
if self.data_parallel_world_size > 1:
fname = f"{fname}.{self.data_parallel_rank}"
return open(fname, "w", buffering=1)
def merge_shards(self) -> None:
"""Merges all shard files into shard 0, then removes shard suffix."""
shard_id = self.data_parallel_rank
num_shards = self.data_parallel_world_size
if self.data_parallel_world_size > 1:
def merge_shards_with_root(fname: str) -> None:
fname = os.path.join(self.cfg.decoding.results_path, fname)
logger.info("Merging %s on shard %d", fname, shard_id)
base_fpath = Path(f"{fname}.0")
with open(base_fpath, "a") as out_file:
for s in range(1, num_shards):
shard_fpath = Path(f"{fname}.{s}")
with open(shard_fpath, "r") as in_file:
for line in in_file:
out_file.write(line)
shard_fpath.unlink()
shutil.move(f"{fname}.0", fname)
dist.barrier() # ensure all shards finished writing
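            # Spread the merge work across ranks: each results file is merged
            # by a different shard index (modulo the number of shards), so the
            # four files can be merged in parallel when enough ranks exist.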
if shard_id == (0 % num_shards):
merge_shards_with_root("hypo.word")
if shard_id == (1 % num_shards):
merge_shards_with_root("hypo.units")
if shard_id == (2 % num_shards):
merge_shards_with_root("ref.word")
if shard_id == (3 % num_shards):
merge_shards_with_root("ref.units")
dist.barrier()
def optimize_model(self, model: FairseqModel) -> None:
model.make_generation_fast_()
if self.cfg.common.fp16:
model.half()
if not self.cfg.common.cpu:
model.cuda()
def load_model_ensemble(self) -> Tuple[List[FairseqModel], FairseqDataclass]:
arg_overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
models, saved_cfg = checkpoint_utils.load_model_ensemble(
utils.split_paths(self.cfg.common_eval.path, separator="\\"),
arg_overrides=arg_overrides,
task=self.task,
suffix=self.cfg.checkpoint.checkpoint_suffix,
strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=self.cfg.checkpoint.checkpoint_shard_count,
)
for model in models:
self.optimize_model(model)
return models, saved_cfg
    def get_dataset_itr(self, disable_iterator_cache: bool = False):
return self.task.get_batch_iterator(
dataset=self.task.dataset(self.cfg.dataset.gen_subset),
max_tokens=self.cfg.dataset.max_tokens,
max_sentences=self.cfg.dataset.batch_size,
max_positions=(sys.maxsize, sys.maxsize),
ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
seed=self.cfg.common.seed,
num_shards=self.data_parallel_world_size,
shard_id=self.data_parallel_rank,
num_workers=self.cfg.dataset.num_workers,
data_buffer_size=self.cfg.dataset.data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
).next_epoch_itr(shuffle=False)
def build_progress_bar(
self,
epoch: Optional[int] = None,
prefix: Optional[str] = None,
default_log_format: str = "tqdm",
) -> BaseProgressBar:
return progress_bar.progress_bar(
iterator=self.get_dataset_itr(),
log_format=self.cfg.common.log_format,
log_interval=self.cfg.common.log_interval,
epoch=epoch,
prefix=prefix,
tensorboard_logdir=self.cfg.common.tensorboard_logdir,
default_log_format=default_log_format,
)
@property
def data_parallel_world_size(self):
if self.cfg.distributed_training.distributed_world_size == 1:
return 1
return distributed_utils.get_data_parallel_world_size()
@property
def data_parallel_rank(self):
if self.cfg.distributed_training.distributed_world_size == 1:
return 0
return distributed_utils.get_data_parallel_rank()
def process_sentence(
self,
sample: Dict[str, Any],
hypo: Dict[str, Any],
sid: int,
batch_id: int,
) -> Tuple[int, int]:
speaker = None # Speaker can't be parsed from dataset.
if "target_label" in sample:
toks = sample["target_label"]
else:
toks = sample["target"]
toks = toks[batch_id, :]
# Processes hypothesis.
hyp_pieces = self.tgt_dict.string(hypo["tokens"].int().cpu())
if "words" in hypo:
hyp_words = " ".join(hypo["words"])
else:
hyp_words = post_process(hyp_pieces, self.cfg.common_eval.post_process)
# Processes target.
target_tokens = utils.strip_pad(toks, self.tgt_dict.pad())
tgt_pieces = self.tgt_dict.string(target_tokens.int().cpu())
tgt_words = post_process(tgt_pieces, self.cfg.common_eval.post_process)
if self.cfg.decoding.results_path is not None:
print(f"{hyp_pieces} ({speaker}-{sid})", file=self.hypo_units_file)
print(f"{hyp_words} ({speaker}-{sid})", file=self.hypo_words_file)
print(f"{tgt_pieces} ({speaker}-{sid})", file=self.ref_units_file)
print(f"{tgt_words} ({speaker}-{sid})", file=self.ref_words_file)
if not self.cfg.common_eval.quiet:
logger.info(f"HYPO: {hyp_words}")
logger.info(f"REF: {tgt_words}")
logger.info("---------------------")
hyp_words, tgt_words = hyp_words.split(), tgt_words.split()
return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
def process_sample(self, sample: Dict[str, Any]) -> None:
self.gen_timer.start()
hypos = self.task.inference_step(
generator=self.generator,
models=self.models,
sample=sample,
)
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
self.gen_timer.stop(num_generated_tokens)
self.wps_meter.update(num_generated_tokens)
for batch_id, sample_id in enumerate(sample["id"].tolist()):
errs, length = self.process_sentence(
sample=sample,
sid=sample_id,
batch_id=batch_id,
hypo=hypos[batch_id][0],
)
self.total_errors += errs
self.total_length += length
self.log({"wps": round(self.wps_meter.avg)})
if "nsentences" in sample:
self.num_sentences += sample["nsentences"]
else:
self.num_sentences += sample["id"].numel()
def log_generation_time(self) -> None:
logger.info(
"Processed %d sentences (%d tokens) in %.1fs %.2f "
"sentences per second, %.2f tokens per second)",
self.num_sentences,
self.gen_timer.n,
self.gen_timer.sum,
self.num_sentences / self.gen_timer.sum,
1.0 / self.gen_timer.avg,
)
def parse_wer(wer_file: Path) -> float:
with open(wer_file, "r") as f:
return float(f.readline().strip().split(" ")[1])
def get_wer_file(cfg: InferConfig) -> Path:
"""Hashes the decoding parameters to a unique file ID."""
base_path = "wer"
if cfg.decoding.results_path is not None:
base_path = os.path.join(cfg.decoding.results_path, base_path)
if cfg.decoding.unique_wer_file:
yaml_str = OmegaConf.to_yaml(cfg.decoding)
fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16)
return Path(f"{base_path}.{fid % 1000000}")
else:
return Path(base_path)
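# Illustrative note: with `unique_wer_file=True` the decoding config is
# serialized to YAML and hashed with MD5, yielding a deterministic path such
# as "<results_path>/wer.<hash mod 1e6>"; otherwise the plain "wer" file is
# reused (and overwritten) across runs.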
def main(cfg: InferConfig) -> float:
"""Entry point for main processing logic.
Args:
cfg: The inferance configuration to use.
wer: Optional shared memory pointer for returning the WER. If not None,
the final WER value will be written here instead of being returned.
Returns:
The final WER if `wer` is None, otherwise None.
"""
yaml_str, wer_file = OmegaConf.to_yaml(cfg.decoding), get_wer_file(cfg)
# Validates the provided configuration.
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
cfg.dataset.max_tokens = 4000000
if not cfg.common.cpu and not torch.cuda.is_available():
raise ValueError("CUDA not found; set `cpu=True` to run without CUDA")
with InferenceProcessor(cfg) as processor:
for sample in processor:
processor.process_sample(sample)
processor.log_generation_time()
if cfg.decoding.results_path is not None:
processor.merge_shards()
errs_t, leng_t = processor.total_errors, processor.total_length
if cfg.common.cpu:
logger.warning("Merging WER requires CUDA.")
elif processor.data_parallel_world_size > 1:
stats = torch.LongTensor([errs_t, leng_t]).cuda()
dist.all_reduce(stats, op=dist.ReduceOp.SUM)
errs_t, leng_t = stats[0].item(), stats[1].item()
wer = errs_t * 100.0 / leng_t
if distributed_utils.is_master(cfg.distributed_training):
with open(wer_file, "w") as f:
f.write(
(
f"WER: {wer}\n"
f"err / num_ref_words = {errs_t} / {leng_t}\n\n"
f"{yaml_str}"
)
)
return wer
@hydra.main(config_path=config_path, config_name="infer")
def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]:
container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
cfg = OmegaConf.create(container)
OmegaConf.set_struct(cfg, True)
if cfg.common.reset_logging:
reset_logging()
# logger.info("Config:\n%s", OmegaConf.to_yaml(cfg))
wer = float("inf")
try:
if cfg.common.profile:
with torch.cuda.profiler.profile():
with torch.autograd.profiler.emit_nvtx():
distributed_utils.call_main(cfg, main)
else:
distributed_utils.call_main(cfg, main)
wer = parse_wer(get_wer_file(cfg))
except BaseException as e: # pylint: disable=broad-except
if not cfg.common.suppress_crashes:
raise
else:
logger.error("Crashed! %s", str(e))
logger.info("Word error rate: %.4f", wer)
if cfg.is_ax:
return wer, None
return wer
def cli_main() -> None:
try:
from hydra._internal.utils import (
get_args,
) # pylint: disable=import-outside-toplevel
cfg_name = get_args().config_name or "infer"
except ImportError:
logger.warning("Failed to get config name from hydra args")
cfg_name = "infer"
cs = ConfigStore.instance()
cs.store(name=cfg_name, node=InferConfig)
for k in InferConfig.__dataclass_fields__:
if is_dataclass(InferConfig.__dataclass_fields__[k].type):
v = InferConfig.__dataclass_fields__[k].default
cs.store(name=k, node=v)
hydra_main() # pylint: disable=no-value-for-parameter
if __name__ == "__main__":
cli_main()
import importlib
import os
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
task_name = file[: file.find(".py")]
importlib.import_module("examples.speech_recognition.tasks." + task_name)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import re
import sys
import torch
from examples.speech_recognition.data import AsrDataset
from examples.speech_recognition.data.replabels import replabel_symbol
from fairseq.data import Dictionary
from fairseq.tasks import LegacyFairseqTask, register_task
def get_asr_dataset_from_json(data_json_path, tgt_dict):
"""
Parse data json and create dataset.
    See scripts/asr_prep_json.py, which packs the json from raw files.
Json example:
{
"utts": {
"4771-29403-0025": {
"input": {
"length_ms": 170,
"path": "/tmp/file1.flac"
},
"output": {
"text": "HELLO \n",
"token": "HE LLO",
"tokenid": "4815, 861"
}
},
"1564-142299-0096": {
...
}
}
"""
if not os.path.isfile(data_json_path):
raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
with open(data_json_path, "rb") as f:
data_samples = json.load(f)["utts"]
assert len(data_samples) != 0
sorted_samples = sorted(
data_samples.items(),
key=lambda sample: int(sample[1]["input"]["length_ms"]),
reverse=True,
)
aud_paths = [s[1]["input"]["path"] for s in sorted_samples]
ids = [s[0] for s in sorted_samples]
speakers = []
for s in sorted_samples:
m = re.search("(.+?)-(.+?)-(.+?)", s[0])
speakers.append(m.group(1) + "_" + m.group(2))
frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples]
tgt = [
[int(i) for i in s[1]["output"]["tokenid"].split(", ")]
for s in sorted_samples
]
# append eos
tgt = [[*t, tgt_dict.eos()] for t in tgt]
return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
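# Illustrative usage (hypothetical paths, not part of the original module):
# given a "train.json" prepared by scripts/asr_prep_json.py and the task
# dictionary, a dataset can be built directly as:
#
#   tgt_dict = Dictionary.load("data/dict.txt")
#   dataset = get_asr_dataset_from_json("data/train.json", tgt_dict)
#   # utterances are sorted longest-first by length_ms, and eos is appended
#   # to every target token sequence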
@register_task("speech_recognition")
class SpeechRecognitionTask(LegacyFairseqTask):
"""
Task for training speech recognition model.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument("data", help="path to data directory")
parser.add_argument(
"--silence-token", default="\u2581", help="token for silence (used by w2l)"
)
parser.add_argument(
"--max-source-positions",
default=sys.maxsize,
type=int,
metavar="N",
help="max number of frames in the source sequence",
)
parser.add_argument(
"--max-target-positions",
default=1024,
type=int,
metavar="N",
help="max number of tokens in the target sequence",
)
def __init__(self, args, tgt_dict):
super().__init__(args)
self.tgt_dict = tgt_dict
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task (e.g., load dictionaries)."""
dict_path = os.path.join(args.data, "dict.txt")
if not os.path.isfile(dict_path):
raise FileNotFoundError("Dict not found: {}".format(dict_path))
tgt_dict = Dictionary.load(dict_path)
if args.criterion == "ctc_loss":
tgt_dict.add_symbol("<ctc_blank>")
elif args.criterion == "asg_loss":
for i in range(1, args.max_replabel + 1):
tgt_dict.add_symbol(replabel_symbol(i))
print("| dictionary: {} types".format(len(tgt_dict)))
return cls(args, tgt_dict)
def load_dataset(self, split, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
data_json_path = os.path.join(self.args.data, "{}.json".format(split))
self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)
def build_generator(self, models, args, **unused):
w2l_decoder = getattr(args, "w2l_decoder", None)
if w2l_decoder == "viterbi":
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
return W2lViterbiDecoder(args, self.target_dictionary)
elif w2l_decoder == "kenlm":
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
return W2lKenLMDecoder(args, self.target_dictionary)
elif w2l_decoder == "fairseqlm":
from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
return W2lFairseqLMDecoder(args, self.target_dictionary)
else:
return super().build_generator(models, args)
@property
def target_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.tgt_dict
@property
def source_dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary` (if applicable
for this task)."""
return None
def max_positions(self):
"""Return the max speech and sentence length allowed by the task."""
return (self.args.max_source_positions, self.args.max_target_positions)
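# Illustrative sketch (hypothetical arguments, not part of the original file):
# setting up the task outside of fairseq-train, assuming "data/" contains
# dict.txt plus {split}.json files produced by scripts/asr_prep_json.py:
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   SpeechRecognitionTask.add_args(parser)
#   args = parser.parse_args(["data"])
#   args.criterion = "cross_entropy"  # ctc_loss/asg_loss add extra symbols
#   task = SpeechRecognitionTask.setup_task(args)
#   task.load_dataset("train")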
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import re
from collections import deque
from enum import Enum
import numpy as np
"""
Utility modules for computation of Word Error Rate,
Alignments, as well as more granular metrics like
deletion, insersion and substitutions.
"""
class Code(Enum):
match = 1
substitution = 2
insertion = 3
deletion = 4
class Token(object):
def __init__(self, lbl="", st=np.nan, en=np.nan):
if np.isnan(st):
self.label, self.start, self.end = "", 0.0, 0.0
else:
self.label, self.start, self.end = lbl, st, en
class AlignmentResult(object):
def __init__(self, refs, hyps, codes, score):
self.refs = refs # std::deque<int>
self.hyps = hyps # std::deque<int>
self.codes = codes # std::deque<Code>
self.score = score # float
def coordinate_to_offset(row, col, ncols):
return int(row * ncols + col)
def offset_to_row(offset, ncols):
return int(offset / ncols)
def offset_to_col(offset, ncols):
return int(offset % ncols)
def trimWhitespace(str):
return re.sub(" +", " ", re.sub(" *$", "", re.sub("^ *", "", str)))
def str2toks(str):
pieces = trimWhitespace(str).split(" ")
toks = []
for p in pieces:
toks.append(Token(p, 0.0, 0.0))
return toks
class EditDistance(object):
def __init__(self, time_mediated):
self.time_mediated_ = time_mediated
self.scores_ = np.nan # Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>
self.backtraces_ = (
np.nan
) # Eigen::Matrix<size_t, Eigen::Dynamic, Eigen::Dynamic> backtraces_;
self.confusion_pairs_ = {}
def cost(self, ref, hyp, code):
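        # Time-mediated costs penalize temporal misalignment between reference
        # and hypothesis tokens; the standard mode uses fixed Levenshtein-style
        # weights (match=0, insertion=deletion=3, substitution=4).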
if self.time_mediated_:
if code == Code.match:
return abs(ref.start - hyp.start) + abs(ref.end - hyp.end)
elif code == Code.insertion:
return hyp.end - hyp.start
elif code == Code.deletion:
return ref.end - ref.start
else: # substitution
return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + 0.1
else:
if code == Code.match:
return 0
elif code == Code.insertion or code == Code.deletion:
return 3
else: # substitution
return 4
def get_result(self, refs, hyps):
res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=np.nan)
num_rows, num_cols = self.scores_.shape
res.score = self.scores_[num_rows - 1, num_cols - 1]
curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols)
while curr_offset != 0:
curr_row = offset_to_row(curr_offset, num_cols)
curr_col = offset_to_col(curr_offset, num_cols)
prev_offset = self.backtraces_[curr_row, curr_col]
prev_row = offset_to_row(prev_offset, num_cols)
prev_col = offset_to_col(prev_offset, num_cols)
res.refs.appendleft(curr_row - 1) # Note: this was .push_front() in C++
res.hyps.appendleft(curr_col - 1)
if curr_row - 1 == prev_row and curr_col == prev_col:
res.codes.appendleft(Code.deletion)
elif curr_row == prev_row and curr_col - 1 == prev_col:
res.codes.appendleft(Code.insertion)
else:
# assert(curr_row - 1 == prev_row and curr_col - 1 == prev_col)
ref_str = refs[res.refs[0]].label
hyp_str = hyps[res.hyps[0]].label
if ref_str == hyp_str:
res.codes.appendleft(Code.match)
else:
res.codes.appendleft(Code.substitution)
confusion_pair = "%s -> %s" % (ref_str, hyp_str)
if confusion_pair not in self.confusion_pairs_:
self.confusion_pairs_[confusion_pair] = 1
else:
self.confusion_pairs_[confusion_pair] += 1
curr_offset = prev_offset
return res
def align(self, refs, hyps):
if len(refs) == 0 and len(hyps) == 0:
return np.nan
        # NOTE: fresh zero-initialized matrices are allocated on every call;
        # every entry is then filled in by the dynamic-programming loop below.
self.scores_ = np.zeros((len(refs) + 1, len(hyps) + 1))
self.backtraces_ = np.zeros((len(refs) + 1, len(hyps) + 1))
num_rows, num_cols = self.scores_.shape
for i in range(num_rows):
for j in range(num_cols):
if i == 0 and j == 0:
self.scores_[i, j] = 0.0
self.backtraces_[i, j] = 0
continue
if i == 0:
self.scores_[i, j] = self.scores_[i, j - 1] + self.cost(
None, hyps[j - 1], Code.insertion
)
self.backtraces_[i, j] = coordinate_to_offset(i, j - 1, num_cols)
continue
if j == 0:
self.scores_[i, j] = self.scores_[i - 1, j] + self.cost(
refs[i - 1], None, Code.deletion
)
self.backtraces_[i, j] = coordinate_to_offset(i - 1, j, num_cols)
continue
# Below here both i and j are greater than 0
ref = refs[i - 1]
hyp = hyps[j - 1]
best_score = self.scores_[i - 1, j - 1] + (
self.cost(ref, hyp, Code.match)
if (ref.label == hyp.label)
else self.cost(ref, hyp, Code.substitution)
)
prev_row = i - 1
prev_col = j - 1
ins = self.scores_[i, j - 1] + self.cost(None, hyp, Code.insertion)
if ins < best_score:
best_score = ins
prev_row = i
prev_col = j - 1
delt = self.scores_[i - 1, j] + self.cost(ref, None, Code.deletion)
if delt < best_score:
best_score = delt
prev_row = i - 1
prev_col = j
self.scores_[i, j] = best_score
self.backtraces_[i, j] = coordinate_to_offset(
prev_row, prev_col, num_cols
)
return self.get_result(refs, hyps)
class WERTransformer(object):
def __init__(self, hyp_str, ref_str, verbose=True):
self.ed_ = EditDistance(False)
self.id2oracle_errs_ = {}
self.utts_ = 0
self.words_ = 0
self.insertions_ = 0
self.deletions_ = 0
self.substitutions_ = 0
self.process(["dummy_str", hyp_str, ref_str])
if verbose:
print("'%s' vs '%s'" % (hyp_str, ref_str))
self.report_result()
def process(self, input): # std::vector<std::string>&& input
if len(input) < 3:
            print(
                "Input must be of the form <id> ... <hypo> <ref>, got",
                len(input),
                "inputs:",
            )
return None
# Align
# std::vector<Token> hyps;
# std::vector<Token> refs;
hyps = str2toks(input[-2])
refs = str2toks(input[-1])
alignment = self.ed_.align(refs, hyps)
if alignment is None:
print("Alignment is null")
return np.nan
# Tally errors
ins = 0
dels = 0
subs = 0
for code in alignment.codes:
if code == Code.substitution:
subs += 1
elif code == Code.insertion:
ins += 1
elif code == Code.deletion:
dels += 1
# Output
row = input
row.append(str(len(refs)))
row.append(str(ins))
row.append(str(dels))
row.append(str(subs))
# print(row)
# Accumulate
kIdIndex = 0
kNBestSep = "/"
pieces = input[kIdIndex].split(kNBestSep)
if len(pieces) == 0:
print(
"Error splitting ",
input[kIdIndex],
" on '",
kNBestSep,
"', got empty list",
)
return np.nan
id = pieces[0]
if id not in self.id2oracle_errs_:
self.utts_ += 1
self.words_ += len(refs)
self.insertions_ += ins
self.deletions_ += dels
self.substitutions_ += subs
self.id2oracle_errs_[id] = [ins, dels, subs]
else:
curr_err = ins + dels + subs
prev_err = np.sum(self.id2oracle_errs_[id])
if curr_err < prev_err:
self.id2oracle_errs_[id] = [ins, dels, subs]
return 0
def report_result(self):
# print("---------- Summary ---------------")
if self.words_ == 0:
print("No words counted")
return
# 1-best
best_wer = (
100.0
* (self.insertions_ + self.deletions_ + self.substitutions_)
/ self.words_
)
print(
"\tWER = %0.2f%% (%i utts, %i words, %0.2f%% ins, "
"%0.2f%% dels, %0.2f%% subs)"
% (
best_wer,
self.utts_,
self.words_,
100.0 * self.insertions_ / self.words_,
100.0 * self.deletions_ / self.words_,
100.0 * self.substitutions_ / self.words_,
)
)
def wer(self):
if self.words_ == 0:
wer = np.nan
else:
wer = (
100.0
* (self.insertions_ + self.deletions_ + self.substitutions_)
/ self.words_
)
return wer
def stats(self):
if self.words_ == 0:
stats = {}
else:
wer = (
100.0
* (self.insertions_ + self.deletions_ + self.substitutions_)
/ self.words_
)
stats = dict(
{
"wer": wer,
"utts": self.utts_,
"numwords": self.words_,
"ins": self.insertions_,
"dels": self.deletions_,
"subs": self.substitutions_,
"confusion_pairs": self.ed_.confusion_pairs_,
}
)
return stats
def calc_wer(hyp_str, ref_str):
t = WERTransformer(hyp_str, ref_str, verbose=0)
return t.wer()
def calc_wer_stats(hyp_str, ref_str):
t = WERTransformer(hyp_str, ref_str, verbose=0)
return t.stats()
def get_wer_alignment_codes(hyp_str, ref_str):
"""
INPUT: hypothesis string, reference string
OUTPUT: List of alignment codes (intermediate results from WER computation)
"""
t = WERTransformer(hyp_str, ref_str, verbose=0)
return t.ed_.align(str2toks(ref_str), str2toks(hyp_str)).codes
def merge_counts(x, y):
# Merge two hashes which have 'counts' as their values
# This can be used for example to merge confusion pair counts
# conf_pairs = merge_counts(conf_pairs, stats['confusion_pairs'])
for k, v in y.items():
if k not in x:
x[k] = 0
x[k] += v
return x
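# Illustrative usage (made-up strings): computing WER statistics and merging
# confusion-pair counts across utterances:
#
#   stats = calc_wer_stats("the cat sat", "the cat sat down")
#   print(stats["wer"], stats["dels"])  # 25.0 and 1: one deletion vs. a 4-word ref
#   conf_pairs = merge_counts({}, stats["confusion_pairs"])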
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Flashlight decoders.
"""
import gc
import itertools as it
import os.path as osp
from typing import List
import warnings
from collections import deque, namedtuple
import numpy as np
import torch
from examples.speech_recognition.data.replabels import unpack_replabels
from fairseq import tasks
from fairseq.utils import apply_to_sample
from omegaconf import open_dict
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
try:
from flashlight.lib.text.dictionary import create_word_dict, load_words
from flashlight.lib.sequence.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from flashlight.lib.text.decoder import (
CriterionType,
LexiconDecoderOptions,
KenLM,
LM,
LMState,
SmearingMode,
Trie,
LexiconDecoder,
)
except Exception:
warnings.warn(
"flashlight python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/flashlight/tree/master/bindings/python"
)
LM = object
LMState = object
class W2lDecoder(object):
def __init__(self, args, tgt_dict):
self.tgt_dict = tgt_dict
self.vocab_size = len(tgt_dict)
self.nbest = args.nbest
# criterion-specific init
self.criterion_type = CriterionType.CTC
self.blank = (
tgt_dict.index("<ctc_blank>")
if "<ctc_blank>" in tgt_dict.indices
else tgt_dict.bos()
)
if "<sep>" in tgt_dict.indices:
self.silence = tgt_dict.index("<sep>")
elif "|" in tgt_dict.indices:
self.silence = tgt_dict.index("|")
else:
self.silence = tgt_dict.eos()
self.asg_transitions = None
def generate(self, models, sample, **unused):
"""Generate a batch of inferences."""
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
emissions = self.get_emissions(models, encoder_input)
return self.decode(emissions)
def get_emissions(self, models, encoder_input):
"""Run encoder and normalize emissions"""
model = models[0]
encoder_out = model(**encoder_input)
if hasattr(model, "get_logits"):
emissions = model.get_logits(encoder_out) # no need to normalize emissions
else:
emissions = model.get_normalized_probs(encoder_out, log_probs=True)
return emissions.transpose(0, 1).float().cpu().contiguous()
def get_tokens(self, idxs):
"""Normalize tokens by handling CTC blank, ASG replabels, etc."""
idxs = (g[0] for g in it.groupby(idxs))
idxs = filter(lambda x: x != self.blank, idxs)
return torch.LongTensor(list(idxs))
class W2lViterbiDecoder(W2lDecoder):
def __init__(self, args, tgt_dict):
super().__init__(args, tgt_dict)
def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
if self.asg_transitions is None:
transitions = torch.FloatTensor(N, N).zero_()
else:
transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
viterbi_path = torch.IntTensor(B, T)
workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
CpuViterbiPath.compute(
B,
T,
N,
get_data_ptr_as_bytes(emissions),
get_data_ptr_as_bytes(transitions),
get_data_ptr_as_bytes(viterbi_path),
get_data_ptr_as_bytes(workspace),
)
return [
[{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
for b in range(B)
]
class W2lKenLMDecoder(W2lDecoder):
def __init__(self, args, tgt_dict):
super().__init__(args, tgt_dict)
self.unit_lm = getattr(args, "unit_lm", False)
if args.lexicon:
self.lexicon = load_words(args.lexicon)
self.word_dict = create_word_dict(self.lexicon)
self.unk_word = self.word_dict.get_index("<unk>")
self.lm = KenLM(args.kenlm_model, self.word_dict)
self.trie = Trie(self.vocab_size, self.silence)
start_state = self.lm.start(False)
for i, (word, spellings) in enumerate(self.lexicon.items()):
word_idx = self.word_dict.get_index(word)
_, score = self.lm.score(start_state, word_idx)
for spelling in spellings:
spelling_idxs = [tgt_dict.index(token) for token in spelling]
assert (
tgt_dict.unk() not in spelling_idxs
), f"{spelling} {spelling_idxs}"
self.trie.insert(spelling_idxs, word_idx, score)
self.trie.smear(SmearingMode.MAX)
self.decoder_opts = LexiconDecoderOptions(
beam_size=args.beam,
beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
beam_threshold=args.beam_threshold,
lm_weight=args.lm_weight,
word_score=args.word_score,
unk_score=args.unk_weight,
sil_score=args.sil_weight,
log_add=False,
criterion_type=self.criterion_type,
)
if self.asg_transitions is None:
N = 768
# self.asg_transitions = torch.FloatTensor(N, N).zero_()
self.asg_transitions = []
self.decoder = LexiconDecoder(
self.decoder_opts,
self.trie,
self.lm,
self.silence,
self.blank,
self.unk_word,
self.asg_transitions,
self.unit_lm,
)
else:
            assert (
                args.unit_lm
            ), "lexicon free decoding can only be done with a unit language model"
            from flashlight.lib.text.decoder import (
                LexiconFreeDecoder,
                LexiconFreeDecoderOptions,
            )
d = {w: [[w]] for w in tgt_dict.symbols}
self.word_dict = create_word_dict(d)
self.lm = KenLM(args.kenlm_model, self.word_dict)
self.decoder_opts = LexiconFreeDecoderOptions(
beam_size=args.beam,
beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
beam_threshold=args.beam_threshold,
lm_weight=args.lm_weight,
sil_score=args.sil_weight,
log_add=False,
criterion_type=self.criterion_type,
)
self.decoder = LexiconFreeDecoder(
self.decoder_opts, self.lm, self.silence, self.blank, []
)
def get_timesteps(self, token_idxs: List[int]) -> List[int]:
"""Returns frame numbers corresponding to every non-blank token.
Parameters
----------
token_idxs : List[int]
IDs of decoded tokens.
Returns
-------
List[int]
Frame numbers corresponding to every non-blank token.
"""
timesteps = []
for i, token_idx in enumerate(token_idxs):
if token_idx == self.blank:
continue
if i == 0 or token_idx != token_idxs[i-1]:
timesteps.append(i)
return timesteps
def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
for b in range(B):
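            # emissions is float32, so each element occupies 4 bytes; advance
            # the raw data pointer to the start of batch element b before
            # handing it to the flashlight decoder.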
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
results = self.decoder.decode(emissions_ptr, T, N)
nbest_results = results[: self.nbest]
hypos.append(
[
{
"tokens": self.get_tokens(result.tokens),
"score": result.score,
"timesteps": self.get_timesteps(result.tokens),
"words": [
self.word_dict.get_entry(x) for x in result.words if x >= 0
],
}
for result in nbest_results
]
)
return hypos
FairseqLMState = namedtuple("FairseqLMState", ["prefix", "incremental_state", "probs"])
class FairseqLM(LM):
def __init__(self, dictionary, model):
LM.__init__(self)
self.dictionary = dictionary
self.model = model
self.unk = self.dictionary.unk()
self.save_incremental = False # this currently does not work properly
self.max_cache = 20_000
model.cuda()
model.eval()
model.make_generation_fast_()
self.states = {}
self.stateq = deque()
def start(self, start_with_nothing):
state = LMState()
prefix = torch.LongTensor([[self.dictionary.eos()]])
incremental_state = {} if self.save_incremental else None
with torch.no_grad():
res = self.model(prefix.cuda(), incremental_state=incremental_state)
probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)
if incremental_state is not None:
incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)
self.states[state] = FairseqLMState(
prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
)
self.stateq.append(state)
return state
def score(self, state: LMState, token_index: int, no_cache: bool = False):
"""
Evaluate language model based on the current lm state and new word
Parameters:
-----------
state: current lm state
        token_index: index of the word
            (this can be a lexicon index, in which case the LM must store the
            mapping between lexicon and LM indices, or the LM index of the word)
Returns:
--------
(LMState, float): pair of (new state, score for the current word)
"""
curr_state = self.states[state]
def trim_cache(targ_size):
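            # Evict the oldest cached LM states down to targ_size entries,
            # keeping each key but dropping its probabilities and incremental
            # state so the memory can be reclaimed.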
while len(self.stateq) > targ_size:
rem_k = self.stateq.popleft()
rem_st = self.states[rem_k]
rem_st = FairseqLMState(rem_st.prefix, None, None)
self.states[rem_k] = rem_st
if curr_state.probs is None:
new_incremental_state = (
curr_state.incremental_state.copy()
if curr_state.incremental_state is not None
else None
)
with torch.no_grad():
if new_incremental_state is not None:
new_incremental_state = apply_to_sample(
lambda x: x.cuda(), new_incremental_state
)
elif self.save_incremental:
new_incremental_state = {}
res = self.model(
torch.from_numpy(curr_state.prefix).cuda(),
incremental_state=new_incremental_state,
)
probs = self.model.get_normalized_probs(
res, log_probs=True, sample=None
)
if new_incremental_state is not None:
new_incremental_state = apply_to_sample(
lambda x: x.cpu(), new_incremental_state
)
curr_state = FairseqLMState(
curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
)
if not no_cache:
self.states[state] = curr_state
self.stateq.append(state)
score = curr_state.probs[token_index].item()
trim_cache(self.max_cache)
outstate = state.child(token_index)
if outstate not in self.states and not no_cache:
prefix = np.concatenate(
[curr_state.prefix, torch.LongTensor([[token_index]])], -1
)
incr_state = curr_state.incremental_state
self.states[outstate] = FairseqLMState(prefix, incr_state, None)
if token_index == self.unk:
score = float("-inf")
return outstate, score
def finish(self, state: LMState):
"""
Evaluate eos for language model based on the current lm state
Returns:
--------
(LMState, float): pair of (new state, score for the current word)
"""
return self.score(state, self.dictionary.eos())
def empty_cache(self):
self.states = {}
self.stateq = deque()
gc.collect()
class W2lFairseqLMDecoder(W2lDecoder):
def __init__(self, args, tgt_dict):
super().__init__(args, tgt_dict)
self.unit_lm = getattr(args, "unit_lm", False)
self.lexicon = load_words(args.lexicon) if args.lexicon else None
self.idx_to_wrd = {}
checkpoint = torch.load(args.kenlm_model, map_location="cpu")
if "cfg" in checkpoint and checkpoint["cfg"] is not None:
lm_args = checkpoint["cfg"]
else:
lm_args = convert_namespace_to_omegaconf(checkpoint["args"])
with open_dict(lm_args.task):
lm_args.task.data = osp.dirname(args.kenlm_model)
task = tasks.setup_task(lm_args.task)
model = task.build_model(lm_args.model)
model.load_state_dict(checkpoint["model"], strict=False)
self.trie = Trie(self.vocab_size, self.silence)
self.word_dict = task.dictionary
self.unk_word = self.word_dict.unk()
self.lm = FairseqLM(self.word_dict, model)
if self.lexicon:
start_state = self.lm.start(False)
for i, (word, spellings) in enumerate(self.lexicon.items()):
if self.unit_lm:
word_idx = i
self.idx_to_wrd[i] = word
score = 0
else:
word_idx = self.word_dict.index(word)
_, score = self.lm.score(start_state, word_idx, no_cache=True)
for spelling in spellings:
spelling_idxs = [tgt_dict.index(token) for token in spelling]
assert (
tgt_dict.unk() not in spelling_idxs
), f"{spelling} {spelling_idxs}"
self.trie.insert(spelling_idxs, word_idx, score)
self.trie.smear(SmearingMode.MAX)
self.decoder_opts = LexiconDecoderOptions(
beam_size=args.beam,
beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
beam_threshold=args.beam_threshold,
lm_weight=args.lm_weight,
word_score=args.word_score,
unk_score=args.unk_weight,
sil_score=args.sil_weight,
log_add=False,
criterion_type=self.criterion_type,
)
self.decoder = LexiconDecoder(
self.decoder_opts,
self.trie,
self.lm,
self.silence,
self.blank,
self.unk_word,
[],
self.unit_lm,
)
else:
            assert (
                args.unit_lm
            ), "lexicon free decoding can only be done with a unit language model"
            from flashlight.lib.text.decoder import (
                LexiconFreeDecoder,
                LexiconFreeDecoderOptions,
            )
d = {w: [[w]] for w in tgt_dict.symbols}
self.word_dict = create_word_dict(d)
self.lm = KenLM(args.kenlm_model, self.word_dict)
self.decoder_opts = LexiconFreeDecoderOptions(
beam_size=args.beam,
beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
beam_threshold=args.beam_threshold,
lm_weight=args.lm_weight,
sil_score=args.sil_weight,
log_add=False,
criterion_type=self.criterion_type,
)
self.decoder = LexiconFreeDecoder(
self.decoder_opts, self.lm, self.silence, self.blank, []
)
def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
def idx_to_word(idx):
if self.unit_lm:
return self.idx_to_wrd[idx]
else:
return self.word_dict[idx]
def make_hypo(result):
hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
if self.lexicon:
hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
return hypo
for b in range(B):
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
results = self.decoder.decode(emissions_ptr, T, N)
nbest_results = results[: self.nbest]
hypos.append([make_hypo(result) for result in nbest_results])
self.lm.empty_cache()
return hypos
Speech Synthesis (S^2)
===
Speech synthesis with fairseq.
- Autoregressive and non-autoregressive models
- Multi-speaker synthesis
- Audio preprocessing
- Automatic metrics
- Similar data configuration as [S2T](../speech_to_text/README.md)
## Examples
- [Single-speaker synthesis on LJSpeech](docs/ljspeech_example.md)
- [Multi-speaker synthesis on VCTK](docs/vctk_example.md)
- [Multi-speaker synthesis on Common Voice](docs/common_voice_example.md)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.