Commit 18d27e00 authored by wangwei990215's avatar wangwei990215
Browse files

initial commit

parent 541f4c7a
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from examples.simultaneous_translation.utils.functions import (
exclusive_cumprod,
lengths_to_mask,
)
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules import MultiheadAttention
from fairseq.utils import convert_padding_direction
from . import register_monotonic_attention
@with_incremental_state
class MonotonicAttention(nn.Module):
"""
Abstract class of monotonic attentions
"""
def __init__(self, args):
self.eps = args.attention_eps
self.mass_preservation = args.mass_preservation
self.noise_mean = args.noise_mean
self.noise_var = args.noise_var
self.energy_bias_init = args.energy_bias_init
self.energy_bias = (
nn.Parameter(self.energy_bias_init * torch.ones([1]))
if args.energy_bias is True
else 0
)
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--no-mass-preservation', action="store_false", dest="mass_preservation",
help='Do not stay on the last token when decoding')
parser.add_argument('--mass-preservation', action="store_true", dest="mass_preservation",
help='Stay on the last token when decoding')
parser.set_defaults(mass_preservation=True)
parser.add_argument('--noise-var', type=float, default=1.0,
help='Variance of discretness noise')
parser.add_argument('--noise-mean', type=float, default=0.0,
help='Mean of discretness noise')
parser.add_argument('--energy-bias', action="store_true", default=False,
help='Bias for energy')
parser.add_argument('--energy-bias-init', type=float, default=-2.0,
help='Initial value of the bias for energy')
parser.add_argument('--attention-eps', type=float, default=1e-6,
help='Epsilon when calculating expected attention')
# fmt: on
def p_choose(self, *args):
raise NotImplementedError
def input_projections(self, *args):
raise NotImplementedError
def attn_energy(self, q_proj, k_proj, key_padding_mask=None):
"""
Calculating monotonic energies
============================================================
Expected input size
q_proj: bsz * num_heads, tgt_len, self.head_dim
k_proj: bsz * num_heads, src_len, self.head_dim
key_padding_mask: bsz, src_len
attn_mask: tgt_len, src_len
"""
bsz, tgt_len, embed_dim = q_proj.size()
bsz = bsz // self.num_heads
src_len = k_proj.size(1)
attn_energy = torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias
attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len)
if key_padding_mask is not None:
attn_energy = attn_energy.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
float("-inf"),
)
return attn_energy
def expected_alignment_train(self, p_choose, key_padding_mask):
"""
Calculating expected alignment for MMA
Mask is not need because p_choose will be 0 if masked
q_ij = (1 − p_{ij−1})q_{ij−1} + a+{i−1j}
a_ij = p_ij q_ij
parellel solution:
ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi))
============================================================
Expected input size
p_choose: bsz * num_heads, tgt_len, src_len
"""
# p_choose: bsz * num_heads, tgt_len, src_len
bsz_num_heads, tgt_len, src_len = p_choose.size()
# cumprod_1mp : bsz * num_heads, tgt_len, src_len
cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)
init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
init_attention[:, :, 0] = 1.0
previous_attn = [init_attention]
for i in range(tgt_len):
# p_choose: bsz * num_heads, tgt_len, src_len
# cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len
# previous_attn[i]: bsz * num_heads, 1, src_len
# alpha_i: bsz * num_heads, src_len
alpha_i = (
p_choose[:, i]
* cumprod_1mp[:, i]
* torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
).clamp(0, 1.0)
previous_attn.append(alpha_i.unsqueeze(1))
# alpha: bsz * num_heads, tgt_len, src_len
alpha = torch.cat(previous_attn[1:], dim=1)
if self.mass_preservation:
# Last token has the residual probabilities
alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)
assert not torch.isnan(alpha).any(), "NaN detected in alpha."
return alpha
def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state):
"""
Calculating mo alignment for MMA during inference time
============================================================
Expected input size
p_choose: bsz * num_heads, tgt_len, src_len
key_padding_mask: bsz * src_len
incremental_state: dict
"""
# p_choose: bsz * self.num_heads, src_len
bsz_num_heads, tgt_len, src_len = p_choose.size()
# One token at a time
assert tgt_len == 1
p_choose = p_choose[:, 0, :]
monotonic_cache = self._get_monotonic_buffer(incremental_state)
# prev_monotonic_step: bsz, num_heads
bsz = bsz_num_heads // self.num_heads
prev_monotonic_step = monotonic_cache.get(
"step", p_choose.new_zeros([bsz, self.num_heads]).long()
)
bsz, num_heads = prev_monotonic_step.size()
assert num_heads == self.num_heads
assert bsz * num_heads == bsz_num_heads
# p_choose: bsz, num_heads, src_len
p_choose = p_choose.view(bsz, num_heads, src_len)
if key_padding_mask is not None:
src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long()
else:
src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len
# src_lengths: bsz, num_heads
src_lengths = src_lengths.expand_as(prev_monotonic_step)
# new_monotonic_step: bsz, num_heads
new_monotonic_step = prev_monotonic_step
step_offset = 0
if key_padding_mask is not None:
if key_padding_mask[:, 0].any():
# left_pad_source = True:
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
max_steps = src_lengths - 1 if self.mass_preservation else src_lengths
# finish_read: bsz, num_heads
finish_read = new_monotonic_step.eq(max_steps)
while finish_read.sum().item() < bsz * self.num_heads:
# p_choose: bsz * self.num_heads, src_len
# only choose the p at monotonic steps
# p_choose_i: bsz , self.num_heads
p_choose_i = (
p_choose.gather(
2,
(step_offset + new_monotonic_step)
.unsqueeze(2)
.clamp(0, src_len - 1),
)
).squeeze(2)
action = (
(p_choose_i < 0.5)
.type_as(prev_monotonic_step)
.masked_fill(finish_read, 0)
)
# 1 x bsz
# sample actions on unfinished seq
# 1 means stay, finish reading
# 0 means leave, continue reading
# dist = torch.distributions.bernoulli.Bernoulli(p_choose)
# action = dist.sample().type_as(finish_read) * (1 - finish_read)
new_monotonic_step += action
finish_read = new_monotonic_step.eq(max_steps) | (action == 0)
# finish_read = (~ (finish_read.sum(dim=1, keepdim=True) < self.num_heads / 2)) | finish_read
monotonic_cache["step"] = new_monotonic_step
# alpha: bsz * num_heads, 1, src_len
# new_monotonic_step: bsz, num_heads
alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter(
1,
(step_offset + new_monotonic_step)
.view(bsz * self.num_heads, 1)
.clamp(0, src_len - 1),
1,
)
if not self.mass_preservation:
alpha = alpha.masked_fill(
(new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0
)
alpha = alpha.unsqueeze(1)
self._set_monotonic_buffer(incremental_state, monotonic_cache)
return alpha
def v_proj_output(self, value):
raise NotImplementedError
def forward(
self,
query,
key,
value,
key_padding_mask=None,
incremental_state=None,
*args,
**kwargs,
):
tgt_len, bsz, embed_dim = query.size()
src_len = value.size(0)
# stepwise prob
# p_choose: bsz * self.num_heads, tgt_len, src_len
p_choose = self.p_choose(query, key, key_padding_mask)
# expected alignment alpha
# bsz * self.num_heads, tgt_len, src_len
if incremental_state is not None:
alpha = self.expected_alignment_infer(
p_choose, key_padding_mask, incremental_state
)
else:
alpha = self.expected_alignment_train(p_choose, key_padding_mask)
# expected attention beta
# bsz * self.num_heads, tgt_len, src_len
beta = self.expected_attention(
alpha, query, key, value, key_padding_mask, incremental_state
)
attn_weights = beta
v_proj = self.v_proj_output(value)
attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)
return attn, {"alpha": alpha, "beta": beta, "p_choose": p_choose}
def reorder_incremental_state(self, incremental_state, new_order):
"""Reorder buffered internal state (for incremental generation)."""
super().reorder_incremental_state(incremental_state, new_order)
input_buffer = self._get_monotonic_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
input_buffer[k] = input_buffer[k].index_select(0, new_order)
self._set_monotonic_buffer(incremental_state, input_buffer)
def _get_monotonic_buffer(self, incremental_state):
return (
utils.get_incremental_state(
self,
incremental_state,
"monotonic",
)
or {}
)
def _set_monotonic_buffer(self, incremental_state, buffer):
utils.set_incremental_state(
self,
incremental_state,
"monotonic",
buffer,
)
def get_pointer(self, incremental_state):
return (
utils.get_incremental_state(
self,
incremental_state,
"monotonic",
)
or {}
)
def get_fastest_pointer(self, incremental_state):
return self.get_pointer(incremental_state)["step"].max(0)[0]
def set_pointer(self, incremental_state, p_choose):
curr_pointer = self.get_pointer(incremental_state)
if len(curr_pointer) == 0:
buffer = torch.zeros_like(p_choose)
else:
buffer = self.get_pointer(incremental_state)["step"]
buffer += (p_choose < 0.5).type_as(buffer)
utils.set_incremental_state(
self,
incremental_state,
"monotonic",
{"step": buffer},
)
@register_monotonic_attention("hard_aligned")
class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
def __init__(self, args):
MultiheadAttention.__init__(
self,
embed_dim=args.decoder_embed_dim,
num_heads=args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
MonotonicAttention.__init__(self, args)
self.k_in_proj = {"monotonic": self.k_proj}
self.q_in_proj = {"monotonic": self.q_proj}
self.v_in_proj = {"output": self.v_proj}
def input_projections(self, query, key, value, name):
"""
Prepare inputs for multihead attention
============================================================
Expected input size
query: tgt_len, bsz, embed_dim
key: src_len, bsz, embed_dim
value: src_len, bsz, embed_dim
name: monotonic or soft
"""
if query is not None:
bsz = query.size(1)
q = self.q_in_proj[name](query)
q *= self.scaling
q = (
q.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
else:
q = None
if key is not None:
bsz = key.size(1)
k = self.k_in_proj[name](key)
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
else:
k = None
if value is not None:
bsz = value.size(1)
v = self.v_in_proj[name](value)
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
else:
v = None
return q, k, v
def p_choose(self, query, key, key_padding_mask=None):
"""
Calculating step wise prob for reading and writing
1 to read, 0 to write
============================================================
Expected input size
query: bsz, tgt_len, embed_dim
key: bsz, src_len, embed_dim
value: bsz, src_len, embed_dim
key_padding_mask: bsz, src_len
attn_mask: bsz, src_len
query: bsz, tgt_len, embed_dim
"""
# prepare inputs
q_proj, k_proj, _ = self.input_projections(query, key, None, "monotonic")
# attention energy
attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask)
noise = 0
if self.training:
# add noise here to encourage discretness
noise = (
torch.normal(self.noise_mean, self.noise_var, attn_energy.size())
.type_as(attn_energy)
.to(attn_energy.device)
)
p_choose = torch.sigmoid(attn_energy + noise)
_, _, tgt_len, src_len = p_choose.size()
# p_choose: bsz * self.num_heads, tgt_len, src_len
return p_choose.view(-1, tgt_len, src_len)
def expected_attention(self, alpha, *args):
"""
For MMA-H, beta = alpha
"""
return alpha
def v_proj_output(self, value):
_, _, v_proj = self.input_projections(None, None, value, "output")
return v_proj
@register_monotonic_attention("infinite_lookback")
class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHard):
def __init__(self, args):
super().__init__(args)
self.init_soft_attention()
def init_soft_attention(self):
self.k_proj_soft = nn.Linear(self.kdim, self.embed_dim, bias=True)
self.q_proj_soft = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.k_in_proj["soft"] = self.k_proj_soft
self.q_in_proj["soft"] = self.q_proj_soft
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(
self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
nn.init.xavier_uniform_(
self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
else:
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)
def expected_attention(
self, alpha, query, key, value, key_padding_mask, incremental_state
):
# monotonic attention, we will calculate milk here
bsz_x_num_heads, tgt_len, src_len = alpha.size()
bsz = int(bsz_x_num_heads / self.num_heads)
q, k, _ = self.input_projections(query, key, None, "soft")
soft_energy = self.attn_energy(q, k, key_padding_mask)
assert list(soft_energy.size()) == [bsz, self.num_heads, tgt_len, src_len]
soft_energy = soft_energy.view(bsz * self.num_heads, tgt_len, src_len)
if incremental_state is not None:
monotonic_cache = self._get_monotonic_buffer(incremental_state)
monotonic_step = monotonic_cache["step"] + 1
step_offset = 0
if key_padding_mask is not None:
if key_padding_mask[:, 0].any():
# left_pad_source = True:
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
monotonic_step += step_offset
mask = lengths_to_mask(
monotonic_step.view(-1), soft_energy.size(2), 1
).unsqueeze(1)
soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf"))
soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
exp_soft_energy = torch.exp(soft_energy)
exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
beta = exp_soft_energy / exp_soft_energy_sum.unsqueeze(2)
else:
# bsz * num_heads, tgt_len, src_len
soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
exp_soft_energy = torch.exp(soft_energy)
exp_soft_energy_cumsum = torch.cumsum(exp_soft_energy, dim=2)
if key_padding_mask is not None:
if key_padding_mask.any():
exp_soft_energy_cumsum = (
exp_soft_energy_cumsum.view(
-1, self.num_heads, tgt_len, src_len
)
.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps
)
.view(-1, tgt_len, src_len)
)
inner_items = alpha / exp_soft_energy_cumsum
beta = exp_soft_energy * torch.cumsum(
inner_items.flip(dims=[2]), dim=2
).flip(dims=[2])
beta = self.dropout_module(beta)
assert not torch.isnan(beta).any(), "NaN detected in beta."
return beta
@register_monotonic_attention("waitk")
class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookback):
def __init__(self, args):
super().__init__(args)
self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
self.waitk_lagging = args.waitk_lagging
assert (
self.waitk_lagging > 0
), f"Lagging has to been larger than 0, get {self.waitk_lagging}."
@staticmethod
def add_args(parser):
super(
MonotonicMultiheadAttentionWaitk,
MonotonicMultiheadAttentionWaitk,
).add_args(parser)
parser.add_argument(
"--waitk-lagging", type=int, required=True, help="Wait k lagging"
)
def p_choose(
self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None
):
"""
query: bsz, tgt_len
key: bsz, src_len
key_padding_mask: bsz, src_len
"""
src_len, bsz, _ = key.size()
tgt_len, bsz, _ = query.size()
p_choose = query.new_ones(bsz, tgt_len, src_len)
p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1)
p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1)
if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any():
# Left pad source
# add -1 to the end
p_choose = p_choose.masked_fill(
key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1
)
p_choose = convert_padding_direction(
p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True
)
p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query)
# remove -1
p_choose[p_choose.eq(-1)] = 0
# Extend to each head
p_choose = (
p_choose.contiguous()
.unsqueeze(1)
.expand(-1, self.num_heads, -1, -1)
.contiguous()
.view(-1, tgt_len, src_len)
)
return p_choose
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.modules import LayerNorm, TransformerDecoderLayer, TransformerEncoderLayer
from . import build_monotonic_attention
class TransformerMonotonicEncoderLayer(TransformerEncoderLayer):
def forward(self, x, encoder_padding_mask):
seq_len, _, _ = x.size()
attn_mask = x.new_ones([seq_len, seq_len]).triu(1)
attn_mask = attn_mask.masked_fill(attn_mask.bool(), float("-inf"))
return super().forward(x, encoder_padding_mask, attn_mask)
class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
def __init__(
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
):
super().__init__(
args,
no_encoder_attn=True,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
self.encoder_attn = build_monotonic_attention(args)
self.encoder_attn_layer_norm = LayerNorm(
self.embed_dim, export=getattr(args, "char_inputs", False)
)
def prune_incremental_state(self, incremental_state):
def prune(module):
input_buffer = module._get_input_buffer(incremental_state)
for key in ["prev_key", "prev_value"]:
if input_buffer[key].size(2) > 1:
input_buffer[key] = input_buffer[key][:, :, :-1, :]
else:
input_buffer = {}
break
module._set_input_buffer(incremental_state, input_buffer)
prune(self.self_attn)
def get_steps(self, incremental_state):
return self.encoder_attn._get_monotonic_buffer(incremental_state).get("step", 0)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
# automatically import any Python files in the criterions/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("examples.simultaneous_translation.utils." + module)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
"""
Implementing exclusive cumprod.
There is cumprod in pytorch, however there is no exclusive mode.
cumprod(x) = [x1, x1x2, x2x3x4, ..., prod_{i=1}^n x_i]
exclusive means cumprod(x) = [1, x1, x1x2, x1x2x3, ..., prod_{i=1}^{n-1} x_i]
"""
tensor_size = list(tensor.size())
tensor_size[dim] = 1
return_tensor = safe_cumprod(
torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim),
dim=dim,
eps=eps,
)
if dim == 0:
return return_tensor[:-1]
elif dim == 1:
return return_tensor[:, :-1]
elif dim == 2:
return return_tensor[:, :, :-1]
else:
raise RuntimeError("Cumprod on dimension 3 and more is not implemented")
def safe_cumprod(tensor, dim: int, eps: float = 1e-10):
"""
An implementation of cumprod to prevent precision issue.
cumprod(x)
= [x1, x1x2, x1x2x3, ....]
= [exp(log(x1)), exp(log(x1) + log(x2)), exp(log(x1) + log(x2) + log(x3)), ...]
= exp(cumsum(log(x)))
"""
if (tensor + eps < 0).any().item():
raise RuntimeError(
"Safe cumprod can only take non-negative tensors as input."
"Consider use torch.cumprod if you want to calculate negative values."
)
log_tensor = torch.log(tensor + eps)
cumsum_log_tensor = torch.cumsum(log_tensor, dim)
exp_cumsum_log_tensor = torch.exp(cumsum_log_tensor)
return exp_cumsum_log_tensor
def lengths_to_mask(lengths, max_len: int, dim: int = 0, negative_mask: bool = False):
"""
Convert a tensor of lengths to mask
For example, lengths = [[2, 3, 4]], max_len = 5
mask =
[[1, 1, 1],
[1, 1, 1],
[0, 1, 1],
[0, 0, 1],
[0, 0, 0]]
"""
assert len(lengths.size()) <= 2
if len(lengths) == 2:
if dim == 1:
lengths = lengths.t()
lengths = lengths
else:
lengths = lengths.unsqueeze(1)
# lengths : batch_size, 1
lengths = lengths.view(-1, 1)
batch_size = lengths.size(0)
# batch_size, max_len
mask = torch.arange(max_len).expand(batch_size, max_len).type_as(lengths) < lengths
if negative_mask:
mask = ~mask
if dim == 0:
# max_len, batch_size
mask = mask.t()
return mask
def moving_sum(x, start_idx: int, end_idx: int):
"""
From MONOTONIC CHUNKWISE ATTENTION
https://arxiv.org/pdf/1712.05382.pdf
Equation (18)
x = [x_1, x_2, ..., x_N]
MovingSum(x, start_idx, end_idx)_n = Sigma_{m=n−(start_idx−1)}^{n+end_idx-1} x_m
for n in {1, 2, 3, ..., N}
x : src_len, batch_size
start_idx : start idx
end_idx : end idx
Example
src_len = 5
batch_size = 3
x =
[[ 0, 5, 10],
[ 1, 6, 11],
[ 2, 7, 12],
[ 3, 8, 13],
[ 4, 9, 14]]
MovingSum(x, 3, 1) =
[[ 0, 5, 10],
[ 1, 11, 21],
[ 3, 18, 33],
[ 6, 21, 36],
[ 9, 24, 39]]
MovingSum(x, 1, 3) =
[[ 3, 18, 33],
[ 6, 21, 36],
[ 9, 24, 39],
[ 7, 17, 27],
[ 4, 9, 14]]
"""
assert start_idx > 0 and end_idx > 0
assert len(x.size()) == 2
src_len, batch_size = x.size()
# batch_size, 1, src_len
x = x.t().unsqueeze(1)
# batch_size, 1, src_len
moving_sum_weight = x.new_ones([1, 1, end_idx + start_idx - 1])
moving_sum = (
torch.nn.functional.conv1d(
x, moving_sum_weight, padding=start_idx + end_idx - 1
)
.squeeze(1)
.t()
)
moving_sum = moving_sum[end_idx:-start_idx]
assert src_len == moving_sum.size(0)
assert batch_size == moving_sum.size(1)
return moving_sum
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
class LatencyMetric(object):
@staticmethod
def length_from_padding_mask(padding_mask, batch_first: bool = False):
dim = 1 if batch_first else 0
return padding_mask.size(dim) - padding_mask.sum(dim=dim, keepdim=True)
def prepare_latency_metric(
self,
delays,
src_lens,
target_padding_mask=None,
batch_first: bool = False,
start_from_zero: bool = True,
):
assert len(delays.size()) == 2
assert len(src_lens.size()) == 2
if start_from_zero:
delays = delays + 1
if batch_first:
# convert to batch_last
delays = delays.t()
src_lens = src_lens.t()
tgt_len, bsz = delays.size()
_, bsz_1 = src_lens.size()
if target_padding_mask is not None:
target_padding_mask = target_padding_mask.t()
tgt_len_1, bsz_2 = target_padding_mask.size()
assert tgt_len == tgt_len_1
assert bsz == bsz_2
assert bsz == bsz_1
if target_padding_mask is None:
tgt_lens = tgt_len * delays.new_ones([1, bsz]).float()
else:
# 1, batch_size
tgt_lens = self.length_from_padding_mask(target_padding_mask, False).float()
delays = delays.masked_fill(target_padding_mask, 0)
return delays, src_lens, tgt_lens, target_padding_mask
def __call__(
self,
delays,
src_lens,
target_padding_mask=None,
batch_first: bool = False,
start_from_zero: bool = True,
):
delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric(
delays, src_lens, target_padding_mask, batch_first, start_from_zero
)
return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask)
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
"""
Expected sizes:
delays: tgt_len, batch_size
src_lens: 1, batch_size
target_padding_mask: tgt_len, batch_size
"""
raise NotImplementedError
class AverageProportion(LatencyMetric):
"""
Function to calculate Average Proportion from
Can neural machine translation do simultaneous translation?
(https://arxiv.org/abs/1606.02012)
Delays are monotonic steps, range from 1 to src_len.
Give src x tgt y, AP is calculated as:
AP = 1 / (|x||y]) sum_i^|Y| deleys_i
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
if target_padding_mask is not None:
AP = torch.sum(
delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True
)
else:
AP = torch.sum(delays, dim=0, keepdim=True)
AP = AP / (src_lens * tgt_lens)
return AP
class AverageLagging(LatencyMetric):
"""
Function to calculate Average Lagging from
STACL: Simultaneous Translation with Implicit Anticipation
and Controllable Latency using Prefix-to-Prefix Framework
(https://arxiv.org/abs/1810.08398)
Delays are monotonic steps, range from 1 to src_len.
Give src x tgt y, AP is calculated as:
AL = 1 / tau sum_i^tau delays_i - (i - 1) / gamma
Where
gamma = |y| / |x|
tau = argmin_i(delays_i = |x|)
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
# tau = argmin_i(delays_i = |x|)
tgt_len, bsz = delays.size()
lagging_padding_mask = delays >= src_lens
lagging_padding_mask = torch.nn.functional.pad(
lagging_padding_mask.t(), (1, 0)
).t()[:-1, :]
gamma = tgt_lens / src_lens
lagging = (
delays
- torch.arange(delays.size(0))
.unsqueeze(1)
.type_as(delays)
.expand_as(delays)
/ gamma
)
lagging.masked_fill_(lagging_padding_mask, 0)
tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True)
AL = lagging.sum(dim=0, keepdim=True) / tau
return AL
class DifferentiableAverageLagging(LatencyMetric):
"""
Function to calculate Differentiable Average Lagging from
Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
(https://arxiv.org/abs/1906.05218)
Delays are monotonic steps, range from 0 to src_len-1.
(In the original paper thery are from 1 to src_len)
Give src x tgt y, AP is calculated as:
DAL = 1 / |Y| sum_i^|Y| delays'_i - (i - 1) / gamma
Where
delays'_i =
1. delays_i if i == 1
2. max(delays_i, delays'_{i-1} + 1 / gamma)
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
tgt_len, bsz = delays.size()
gamma = tgt_lens / src_lens
new_delays = torch.zeros_like(delays)
for i in range(delays.size(0)):
if i == 0:
new_delays[i] = delays[i]
else:
new_delays[i] = torch.cat(
[
new_delays[i - 1].unsqueeze(0) + 1 / gamma,
delays[i].unsqueeze(0),
],
dim=0,
).max(dim=0)[0]
DAL = (
new_delays
- torch.arange(delays.size(0))
.unsqueeze(1)
.type_as(delays)
.expand_as(delays)
/ gamma
)
if target_padding_mask is not None:
DAL = DAL.masked_fill(target_padding_mask, 0)
DAL = DAL.sum(dim=0, keepdim=True) / tgt_lens
return DAL
class LatencyMetricVariance(LatencyMetric):
def prepare_latency_metric(
self,
delays,
src_lens,
target_padding_mask=None,
batch_first: bool = True,
start_from_zero: bool = True,
):
assert batch_first
assert len(delays.size()) == 3
assert len(src_lens.size()) == 2
if start_from_zero:
delays = delays + 1
# convert to batch_last
bsz, num_heads_x_layers, tgt_len = delays.size()
bsz_1, _ = src_lens.size()
assert bsz == bsz_1
if target_padding_mask is not None:
bsz_2, tgt_len_1 = target_padding_mask.size()
assert tgt_len == tgt_len_1
assert bsz == bsz_2
if target_padding_mask is None:
tgt_lens = tgt_len * delays.new_ones([bsz, tgt_len]).float()
else:
# batch_size, 1
tgt_lens = self.length_from_padding_mask(target_padding_mask, True).float()
delays = delays.masked_fill(target_padding_mask.unsqueeze(1), 0)
return delays, src_lens, tgt_lens, target_padding_mask
class VarianceDelay(LatencyMetricVariance):
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
"""
delays : bsz, num_heads_x_layers, tgt_len
src_lens : bsz, 1
target_lens : bsz, 1
target_padding_mask: bsz, tgt_len or None
"""
if delays.size(1) == 1:
return delays.new_zeros([1])
variance_delays = delays.var(dim=1)
if target_padding_mask is not None:
variance_delays.masked_fill_(target_padding_mask, 0)
return variance_delays.sum(dim=1, keepdim=True) / tgt_lens
class LatencyInference(object):
def __init__(self, start_from_zero=True):
self.metric_calculator = {
"differentiable_average_lagging": DifferentiableAverageLagging(),
"average_lagging": AverageLagging(),
"average_proportion": AverageProportion(),
}
self.start_from_zero = start_from_zero
def __call__(self, monotonic_step, src_lens):
"""
monotonic_step range from 0 to src_len. src_len means eos
delays: bsz, tgt_len
src_lens: bsz, 1
"""
if not self.start_from_zero:
monotonic_step -= 1
src_lens = src_lens
delays = monotonic_step.view(
monotonic_step.size(0), -1, monotonic_step.size(-1)
).max(dim=1)[0]
delays = delays.masked_fill(delays >= src_lens, 0) + (src_lens - 1).expand_as(
delays
).masked_fill(delays < src_lens, 0)
return_dict = {}
for key, func in self.metric_calculator.items():
return_dict[key] = func(
delays.float(),
src_lens.float(),
target_padding_mask=None,
batch_first=True,
start_from_zero=True,
).t()
return return_dict
class LatencyTraining(object):
def __init__(
self,
avg_weight,
var_weight,
avg_type,
var_type,
stay_on_last_token,
average_method,
):
self.avg_weight = avg_weight
self.var_weight = var_weight
self.avg_type = avg_type
self.var_type = var_type
self.stay_on_last_token = stay_on_last_token
self.average_method = average_method
self.metric_calculator = {
"differentiable_average_lagging": DifferentiableAverageLagging(),
"average_lagging": AverageLagging(),
"average_proportion": AverageProportion(),
}
self.variance_calculator = {
"variance_delay": VarianceDelay(),
}
def expected_delays_from_attention(
self, attention, source_padding_mask=None, target_padding_mask=None
):
if type(attention) == list:
# bsz, num_heads, tgt_len, src_len
bsz, num_heads, tgt_len, src_len = attention[0].size()
attention = torch.cat(attention, dim=1)
bsz, num_heads_x_layers, tgt_len, src_len = attention.size()
# bsz * num_heads * num_layers, tgt_len, src_len
attention = attention.view(-1, tgt_len, src_len)
else:
# bsz * num_heads * num_layers, tgt_len, src_len
bsz, tgt_len, src_len = attention.size()
num_heads_x_layers = 1
attention = attention.view(-1, tgt_len, src_len)
if not self.stay_on_last_token:
residual_attention = 1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
attention = torch.cat([attention[:, :, :-1], residual_attention], dim=2)
# bsz * num_heads_x_num_layers, tgt_len, src_len for MMA
steps = (
torch.arange(1, 1 + src_len)
.unsqueeze(0)
.unsqueeze(1)
.expand_as(attention)
.type_as(attention)
)
if source_padding_mask is not None:
src_offset = (
source_padding_mask.type_as(attention)
.sum(dim=1, keepdim=True)
.expand(bsz, num_heads_x_layers)
.contiguous()
.view(-1, 1)
)
src_lens = src_len - src_offset
if source_padding_mask[:, 0].any():
# Pad left
src_offset = src_offset.view(-1, 1, 1)
steps = steps - src_offset
steps = steps.masked_fill(steps <= 0, 0)
else:
src_lens = attention.new_ones([bsz, num_heads_x_layers]) * src_len
src_lens = src_lens.view(-1, 1)
# bsz * num_heads_num_layers, tgt_len, src_len
expected_delays = (
(steps * attention).sum(dim=2).view(bsz, num_heads_x_layers, tgt_len)
)
if target_padding_mask is not None:
expected_delays.masked_fill_(target_padding_mask.unsqueeze(1), 0)
return expected_delays, src_lens
def avg_loss(self, expected_delays, src_lens, target_padding_mask):
bsz, num_heads_x_layers, tgt_len = expected_delays.size()
target_padding_mask = (
target_padding_mask.unsqueeze(1)
.expand_as(expected_delays)
.contiguous()
.view(-1, tgt_len)
)
if self.average_method == "average":
# bsz * tgt_len
expected_delays = expected_delays.mean(dim=1)
elif self.average_method == "weighted_average":
weights = torch.nn.functional.softmax(expected_delays, dim=1)
expected_delays = torch.sum(expected_delays * weights, dim=1)
elif self.average_method == "max":
# bsz * num_heads_x_num_layers, tgt_len
expected_delays = expected_delays.max(dim=1)[0]
else:
raise RuntimeError(f"{self.average_method} is not supported")
src_lens = src_lens.view(bsz, -1)[:, :1]
target_padding_mask = target_padding_mask.view(bsz, -1, tgt_len)[:, 0]
if self.avg_weight > 0.0:
if self.avg_type in self.metric_calculator:
average_delays = self.metric_calculator[self.avg_type](
expected_delays,
src_lens,
target_padding_mask,
batch_first=True,
start_from_zero=False,
)
else:
raise RuntimeError(f"{self.avg_type} is not supported.")
# bsz * num_heads_x_num_layers, 1
return self.avg_weight * average_delays.sum()
else:
return 0.0
def var_loss(self, expected_delays, src_lens, target_padding_mask):
src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[
:, :1
]
if self.var_weight > 0.0:
if self.var_type in self.variance_calculator:
variance_delays = self.variance_calculator[self.var_type](
expected_delays,
src_lens,
target_padding_mask,
batch_first=True,
start_from_zero=False,
)
else:
raise RuntimeError(f"{self.var_type} is not supported.")
return self.var_weight * variance_delays.sum()
else:
return 0.0
def loss(self, attention, source_padding_mask=None, target_padding_mask=None):
expected_delays, src_lens = self.expected_delays_from_attention(
attention, source_padding_mask, target_padding_mask
)
latency_loss = 0
latency_loss += self.avg_loss(expected_delays, src_lens, target_padding_mask)
latency_loss += self.var_loss(expected_delays, src_lens, target_padding_mask)
return latency_loss
# Speech Recognition
`examples/speech_recognition` is implementing ASR task in Fairseq, along with needed features, datasets, models and loss functions to train and infer model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660).
## Additional dependencies
On top of main fairseq dependencies there are couple more additional requirements.
1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features.
2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and installed from source from sctk package [here](http://www.openslr.org/4/). Training and inference doesn't require Sclite dependency.
3) [sentencepiece](https://github.com/google/sentencepiece) is required in order to create dataset with word-piece targets.
## Preparing librispeech data
```
./examples/speech_recognition/datasets/prepare-librispeech.sh $DIR_TO_SAVE_RAW_DATA $DIR_FOR_PREPROCESSED_DATA
```
## Training librispeech data
```
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/
```
## Inference for librispeech
`$SET` can be `test_clean` or `test_other`
Any checkpoint in `$MODEL_PATH` can be selected. In this example we are working with `checkpoint_last.pt`
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --max-tokens 25000 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --beam 20 --results-path $RES_DIR --batch-size 40 --gen-subset $SET --user-dir examples/speech_recognition/
```
## Inference for librispeech
```
sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
```
`Sum/Avg` row from first table of the report has WER
## Using wav2letter components
[wav2letter](https://github.com/facebookresearch/wav2letter) now has integration with fairseq. Currently this includes:
* AutoSegmentationCriterion (ASG)
* wav2letter-style Conv/GLU model
* wav2letter's beam search decoder
To use these, follow the instructions on [this page](https://github.com/facebookresearch/wav2letter/tree/master/bindings/python) to install python bindings. Please note that python bindings are for a *subset* of wav2letter and don't require its full dependencies (notably, `flashlight` and `ArrayFire` are *not* required).
To quickly summarize the instructions: first, install [CUDA](https://developer.nvidia.com/cuda-downloads). Then follow these steps:
```
# additional prerequisites - use equivalents for your distro
sudo apt-get install build-essential cmake libatlas-base-dev libfftw3-dev liblzma-dev libbz2-dev libzstd-dev
# install KenLM from source
git clone https://github.com/kpu/kenlm.git
cd kenlm
mkdir -p build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_POSITION_INDEPENDENT_CODE=ON
make -j16
cd ..
export KENLM_ROOT_DIR=$(pwd)
cd ..
# install wav2letter python bindings
git clone https://github.com/facebookresearch/wav2letter.git
cd wav2letter/bindings/python
# make sure your python environment is active at this point
pip install torch packaging
pip install -e .
# try some examples to verify installation succeeded
python ./examples/criterion_example.py
python ./examples/decoder_example.py ../../src/decoder/test
python ./examples/feature_example.py ../../src/feature/test/data
```
## Training librispeech data (wav2letter style, Conv/GLU + ASG loss)
Training command:
```
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition
```
Note that ASG loss currently doesn't do well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`.
## Inference for librispeech (wav2letter decoder, n-gram LM)
Inference command:
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
```
`$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a wav2letter-style lexicon (list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels):
```
doorbell D O 1 R B E L 1 ▁
```
For CTC inference with word-pieces, repetition labels are not used and the lexicon should have most common spellings for each word (one can use sentencepiece's `NBestEncodeAsPieces` for this):
```
doorbell ▁DOOR BE LL
doorbell ▁DOOR B E LL
doorbell ▁DO OR BE LL
doorbell ▁DOOR B EL L
doorbell ▁DOOR BE L L
doorbell ▁DO OR B E LL
doorbell ▁DOOR B E L L
doorbell ▁DO OR B EL L
doorbell ▁DO O R BE LL
doorbell ▁DO OR BE L L
```
Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`).
## Inference for librispeech (wav2letter decoder, viterbi only)
Inference command:
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
```
from . import criterions, models, tasks # noqa
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from examples.speech_recognition.data.replabels import pack_replabels
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("asg_loss")
class ASGCriterion(FairseqCriterion):
@staticmethod
def add_args(parser):
group = parser.add_argument_group("ASG Loss")
group.add_argument(
"--asg-transitions-init",
help="initial diagonal value of transition matrix",
type=float,
default=0.0,
)
group.add_argument(
"--max-replabel", help="maximum # of replabels", type=int, default=2
)
group.add_argument(
"--linseg-updates",
help="# of training updates to use LinSeg initialization",
type=int,
default=0,
)
group.add_argument(
"--hide-linseg-messages",
help="hide messages about LinSeg initialization",
action="store_true",
)
def __init__(
self,
task,
silence_token,
asg_transitions_init,
max_replabel,
linseg_updates,
hide_linseg_messages,
):
from wav2letter.criterion import ASGLoss, CriterionScaleMode
super().__init__(task)
self.tgt_dict = task.target_dictionary
self.eos = self.tgt_dict.eos()
self.silence = (
self.tgt_dict.index(silence_token)
if silence_token in self.tgt_dict
else None
)
self.max_replabel = max_replabel
num_labels = len(self.tgt_dict)
self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT)
self.asg.trans = torch.nn.Parameter(
asg_transitions_init * torch.eye(num_labels), requires_grad=True
)
self.linseg_progress = torch.nn.Parameter(
torch.tensor([0], dtype=torch.int), requires_grad=False
)
self.linseg_maximum = linseg_updates
self.linseg_message_state = "none" if hide_linseg_messages else "start"
@classmethod
def build_criterion(cls, args, task):
return cls(
task,
args.silence_token,
args.asg_transitions_init,
args.max_replabel,
args.linseg_updates,
args.hide_linseg_messages,
)
def linseg_step(self):
if not self.training:
return False
if self.linseg_progress.item() < self.linseg_maximum:
if self.linseg_message_state == "start":
print("| using LinSeg to initialize ASG")
self.linseg_message_state = "finish"
self.linseg_progress.add_(1)
return True
elif self.linseg_message_state == "finish":
print("| finished LinSeg initialization")
self.linseg_message_state = "none"
return False
def replace_eos_with_silence(self, tgt):
if tgt[-1] != self.eos:
return tgt
elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence):
return tgt[:-1]
else:
return tgt[:-1] + [self.silence]
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
B = emissions.size(0)
T = emissions.size(1)
device = emissions.device
target = torch.IntTensor(B, T)
target_size = torch.IntTensor(B)
using_linseg = self.linseg_step()
for b in range(B):
initial_target_size = sample["target_lengths"][b].item()
if initial_target_size == 0:
raise ValueError("target size cannot be zero")
tgt = sample["target"][b, :initial_target_size].tolist()
tgt = self.replace_eos_with_silence(tgt)
tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
tgt = tgt[:T]
if using_linseg:
tgt = [tgt[t * len(tgt) // T] for t in range(T)]
target[b][: len(tgt)] = torch.IntTensor(tgt)
target_size[b] = len(tgt)
loss = self.asg.forward(emissions, target.to(device), target_size.to(device))
if reduce:
loss = torch.sum(loss)
sample_size = (
sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_output = {
"loss": loss_sum / nsentences,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
return agg_output
import importlib
import os
# ASG loss requires wav2letter
files_to_skip = set()
try:
import wav2letter
except ImportError:
files_to_skip.add("ASG_loss.py")
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_") and file not in files_to_skip:
criterion_name = file[: file.find(".py")]
importlib.import_module(
"examples.speech_recognition.criterions." + criterion_name
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import math
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("cross_entropy_acc")
class CrossEntropyWithAccCriterion(FairseqCriterion):
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg
def compute_loss(self, model, net_output, target, reduction, log_probs):
# N, T -> N * T
target = target.view(-1)
lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
if not hasattr(lprobs, "batch_first"):
logging.warning(
"ERROR: we need to know whether "
"batch first for the net output; "
"you need to set batch_first attribute for the return value of "
"model.get_normalized_probs. Now, we assume this is true, but "
"in the future, we will raise exception instead. "
)
batch_first = getattr(lprobs, "batch_first", True)
if not batch_first:
lprobs = lprobs.transpose(0, 1)
# N, T, D -> N * T, D
lprobs = lprobs.view(-1, lprobs.size(-1))
loss = F.nll_loss(
lprobs, target, ignore_index=self.padding_idx, reduction=reduction
)
return lprobs, loss
def get_logging_output(self, sample, target, lprobs, loss):
target = target.view(-1)
mask = target != self.padding_idx
correct = torch.sum(
lprobs.argmax(1).masked_select(mask) == target.masked_select(mask)
)
total = torch.sum(mask)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data), # * sample['ntokens'],
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
"correct": utils.item(correct.data),
"total": utils.item(total.data),
"nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
}
return sample_size, logging_output
def forward(self, model, sample, reduction="sum", log_probs=True):
"""Computes the cross entropy with accuracy metric for the given sample.
This is similar to CrossEntropyCriterion in fairseq, but also
computes accuracy metrics as part of logging
Args:
logprobs (Torch.tensor) of shape N, T, D i.e.
batchsize, timesteps, dimensions
targets (Torch.tensor) of shape N, T i.e batchsize, timesteps
Returns:
tuple: With three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
TODO:
* Currently this Criterion will only work with LSTMEncoderModels or
FairseqModels which have decoder, or Models which return TorchTensor
as net_output.
We need to make a change to support all FairseqEncoder models.
"""
net_output = model(**sample["net_input"])
target = model.get_targets(sample, net_output)
lprobs, loss = self.compute_loss(
model, net_output, target, reduction, log_probs
)
sample_size, logging_output = self.get_logging_output(
sample, target, lprobs, loss
)
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
total_sum = sum(log.get("total", 0) for log in logging_outputs)
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
nframes = sum(log.get("nframes", 0) for log in logging_outputs)
agg_output = {
"loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
# if args.sentence_avg, then sample_size is nsentences, then loss
# is per-sentence loss; else sample_size is ntokens, the loss
# becomes per-output token loss
"ntokens": ntokens,
"nsentences": nsentences,
"nframes": nframes,
"sample_size": sample_size,
"acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
"correct": correct_sum,
"total": total_sum,
# total is the number of validate tokens
}
if sample_size != ntokens:
agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
# loss: per output token loss
# nll_loss: per sentence loss
return agg_output
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .asr_dataset import AsrDataset
__all__ = [
"AsrDataset",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
from fairseq.data import FairseqDataset
from . import data_utils
from .collaters import Seq2SeqCollater
class AsrDataset(FairseqDataset):
"""
A dataset representing speech and corresponding transcription.
Args:
aud_paths: (List[str]): A list of str with paths to audio files.
aud_durations_ms (List[int]): A list of int containing the durations of
audio files.
tgt (List[torch.LongTensor]): A list of LongTensors containing the indices
of target transcriptions.
tgt_dict (~fairseq.data.Dictionary): target vocabulary.
ids (List[str]): A list of utterance IDs.
speakers (List[str]): A list of speakers corresponding to utterances.
num_mel_bins (int): Number of triangular mel-frequency bins (default: 80)
frame_length (float): Frame length in milliseconds (default: 25.0)
frame_shift (float): Frame shift in milliseconds (default: 10.0)
"""
def __init__(
self,
aud_paths,
aud_durations_ms,
tgt,
tgt_dict,
ids,
speakers,
num_mel_bins=80,
frame_length=25.0,
frame_shift=10.0,
):
assert frame_length > 0
assert frame_shift > 0
assert all(x > frame_length for x in aud_durations_ms)
self.frame_sizes = [
int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms
]
assert len(aud_paths) > 0
assert len(aud_paths) == len(aud_durations_ms)
assert len(aud_paths) == len(tgt)
assert len(aud_paths) == len(ids)
assert len(aud_paths) == len(speakers)
self.aud_paths = aud_paths
self.tgt_dict = tgt_dict
self.tgt = tgt
self.ids = ids
self.speakers = speakers
self.num_mel_bins = num_mel_bins
self.frame_length = frame_length
self.frame_shift = frame_shift
self.s2s_collater = Seq2SeqCollater(
0,
1,
pad_index=self.tgt_dict.pad(),
eos_index=self.tgt_dict.eos(),
move_eos_to_beginning=True,
)
def __getitem__(self, index):
import torchaudio
import torchaudio.compliance.kaldi as kaldi
tgt_item = self.tgt[index] if self.tgt is not None else None
path = self.aud_paths[index]
if not os.path.exists(path):
raise FileNotFoundError("Audio file not found: {}".format(path))
sound, sample_rate = torchaudio.load_wav(path)
output = kaldi.fbank(
sound,
num_mel_bins=self.num_mel_bins,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
)
output_cmvn = data_utils.apply_mv_norm(output)
return {"id": index, "data": [output_cmvn.detach(), tgt_item]}
def __len__(self):
return len(self.aud_paths)
def collater(self, samples):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[int]): sample indices to collate
Returns:
dict: a mini-batch suitable for forwarding with a Model
"""
return self.s2s_collater.collate(samples)
def num_tokens(self, index):
return self.frame_sizes[index]
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return (
self.frame_sizes[index],
len(self.tgt[index]) if self.tgt is not None else 0,
)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
return np.arange(len(self))
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This module contains collection of classes which implement
collate functionalities for various tasks.
Collaters should know what data to expect for each sample
and they should pack / collate them into batches
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import torch
from fairseq.data import data_utils as fairseq_data_utils
class Seq2SeqCollater(object):
"""
Implements collate function mainly for seq2seq tasks
This expects each sample to contain feature (src_tokens) and
targets.
This collator is also used for aligned training task.
"""
def __init__(
self,
feature_index=0,
label_index=1,
pad_index=1,
eos_index=2,
move_eos_to_beginning=True,
):
self.feature_index = feature_index
self.label_index = label_index
self.pad_index = pad_index
self.eos_index = eos_index
self.move_eos_to_beginning = move_eos_to_beginning
def _collate_frames(self, frames):
"""Convert a list of 2d frames into a padded 3d tensor
Args:
frames (list): list of 2d frames of size L[i]*f_dim. Where L[i] is
length of i-th frame and f_dim is static dimension of features
Returns:
3d tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
"""
len_max = max(frame.size(0) for frame in frames)
f_dim = frames[0].size(1)
res = frames[0].new(len(frames), len_max, f_dim).fill_(0.0)
for i, v in enumerate(frames):
res[i, : v.size(0)] = v
return res
def collate(self, samples):
"""
utility function to collate samples into batch for speech recognition.
"""
if len(samples) == 0:
return {}
# parse samples into torch tensors
parsed_samples = []
for s in samples:
# skip invalid samples
if s["data"][self.feature_index] is None:
continue
source = s["data"][self.feature_index]
if isinstance(source, (np.ndarray, np.generic)):
source = torch.from_numpy(source)
target = s["data"][self.label_index]
if isinstance(target, (np.ndarray, np.generic)):
target = torch.from_numpy(target).long()
elif isinstance(target, list):
target = torch.LongTensor(target)
parsed_sample = {"id": s["id"], "source": source, "target": target}
parsed_samples.append(parsed_sample)
samples = parsed_samples
id = torch.LongTensor([s["id"] for s in samples])
frames = self._collate_frames([s["source"] for s in samples])
# sort samples by descending number of frames
frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
frames_lengths, sort_order = frames_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
frames = frames.index_select(0, sort_order)
target = None
target_lengths = None
prev_output_tokens = None
if samples[0].get("target", None) is not None:
ntokens = sum(len(s["target"]) for s in samples)
target = fairseq_data_utils.collate_tokens(
[s["target"] for s in samples],
self.pad_index,
self.eos_index,
left_pad=False,
move_eos_to_beginning=False,
)
target = target.index_select(0, sort_order)
target_lengths = torch.LongTensor(
[s["target"].size(0) for s in samples]
).index_select(0, sort_order)
prev_output_tokens = fairseq_data_utils.collate_tokens(
[s["target"] for s in samples],
self.pad_index,
self.eos_index,
left_pad=False,
move_eos_to_beginning=self.move_eos_to_beginning,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
else:
ntokens = sum(len(s["source"]) for s in samples)
batch = {
"id": id,
"ntokens": ntokens,
"net_input": {"src_tokens": frames, "src_lengths": frames_lengths},
"target": target,
"target_lengths": target_lengths,
"nsentences": len(samples),
}
if prev_output_tokens is not None:
batch["net_input"]["prev_output_tokens"] = prev_output_tokens
return batch
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def calc_mean_invstddev(feature):
if len(feature.size()) != 2:
raise ValueError("We expect the input feature to be 2-D tensor")
mean = feature.mean(0)
var = feature.var(0)
# avoid division by ~zero
eps = 1e-8
if (var < eps).any():
return mean, 1.0 / (torch.sqrt(var) + eps)
return mean, 1.0 / torch.sqrt(var)
def apply_mv_norm(features):
# If there is less than 2 spectrograms, the variance cannot be computed (is NaN)
# and normalization is not possible, so return the item as it is
if features.size(0) < 2:
return features
mean, invstddev = calc_mean_invstddev(features)
res = (features - mean) * invstddev
return res
def lengths_to_encoder_padding_mask(lengths, batch_first=False):
"""
convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor
Args:
lengths: a (B, )-shaped tensor
Return:
max_length: maximum length of B sequences
encoder_padding_mask: a (max_length, B) binary mask, where
[t, b] = 0 for t < lengths[b] and 1 otherwise
TODO:
kernelize this function if benchmarking shows this function is slow
"""
max_lengths = torch.max(lengths).item()
bsz = lengths.size(0)
encoder_padding_mask = torch.arange(
max_lengths
).to( # a (T, ) tensor with [0, ..., T-1]
lengths.device
).view( # move to the right device
1, max_lengths
).expand( # reshape to (1, T)-shaped tensor
bsz, -1
) >= lengths.view( # expand to (B, T)-shaped tensor
bsz, 1
).expand(
-1, max_lengths
)
if not batch_first:
return encoder_padding_mask.t(), max_lengths
else:
return encoder_padding_mask, max_lengths
def encoder_padding_mask_to_lengths(
encoder_padding_mask, max_lengths, batch_size, device
):
"""
convert encoder_padding_mask (2-D binary tensor) to a 1-D tensor
Conventionally, encoder output contains a encoder_padding_mask, which is
a 2-D mask in a shape (T, B), whose (t, b) element indicate whether
encoder_out[t, b] is a valid output (=0) or not (=1). Occasionally, we
need to convert this mask tensor to a 1-D tensor in shape (B, ), where
[b] denotes the valid length of b-th sequence
Args:
encoder_padding_mask: a (T, B)-shaped binary tensor or None; if None,
indicating all are valid
Return:
seq_lengths: a (B,)-shaped tensor, where its (b, )-th element is the
number of valid elements of b-th sequence
max_lengths: maximum length of all sequence, if encoder_padding_mask is
not None, max_lengths must equal to encoder_padding_mask.size(0)
batch_size: batch size; if encoder_padding_mask is
not None, max_lengths must equal to encoder_padding_mask.size(1)
device: which device to put the result on
"""
if encoder_padding_mask is None:
return torch.Tensor([max_lengths] * batch_size).to(torch.int32).to(device)
assert encoder_padding_mask.size(0) == max_lengths, "max_lengths does not match"
assert encoder_padding_mask.size(1) == batch_size, "batch_size does not match"
return max_lengths - torch.sum(encoder_padding_mask, dim=0)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Replabel transforms for use with wav2letter's ASG criterion.
"""
def replabel_symbol(i):
"""
Replabel symbols used in wav2letter, currently just "1", "2", ...
This prevents training with numeral tokens, so this might change in the future
"""
return str(i)
def pack_replabels(tokens, dictionary, max_reps):
"""
Pack a token sequence so that repeated symbols are replaced by replabels
"""
if len(tokens) == 0 or max_reps <= 0:
return tokens
replabel_value_to_idx = [0] * (max_reps + 1)
for i in range(1, max_reps + 1):
replabel_value_to_idx[i] = dictionary.index(replabel_symbol(i))
result = []
prev_token = -1
num_reps = 0
for token in tokens:
if token == prev_token and num_reps < max_reps:
num_reps += 1
else:
if num_reps > 0:
result.append(replabel_value_to_idx[num_reps])
num_reps = 0
result.append(token)
prev_token = token
if num_reps > 0:
result.append(replabel_value_to_idx[num_reps])
return result
def unpack_replabels(tokens, dictionary, max_reps):
"""
Unpack a token sequence so that replabels are replaced by repeated symbols
"""
if len(tokens) == 0 or max_reps <= 0:
return tokens
replabel_idx_to_value = {}
for i in range(1, max_reps + 1):
replabel_idx_to_value[dictionary.index(replabel_symbol(i))] = i
result = []
prev_token = -1
for token in tokens:
try:
for _ in range(replabel_idx_to_value[token]):
result.append(prev_token)
prev_token = -1
except KeyError:
result.append(token)
prev_token = token
return result
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import concurrent.futures
import json
import multiprocessing
import os
from collections import namedtuple
from itertools import chain
import sentencepiece as spm
from fairseq.data import Dictionary
MILLISECONDS_TO_SECONDS = 0.001
def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
import torchaudio
input = {}
output = {}
si, ei = torchaudio.info(aud_path)
input["length_ms"] = int(
si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS
)
input["path"] = aud_path
token = " ".join(sp.EncodeAsPieces(lable))
ids = tgt_dict.encode_line(token, append_eos=False)
output["text"] = lable
output["token"] = token
output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids]))
return {utt_id: {"input": input, "output": output}}
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--audio-dirs",
nargs="+",
default=["-"],
required=True,
help="input directories with audio files",
)
parser.add_argument(
"--labels",
required=True,
help="aggregated input labels with format <ID LABEL> per line",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument(
"--spm-model",
required=True,
help="sentencepiece model to use for encoding",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument(
"--dictionary",
required=True,
help="file to load fairseq dictionary from",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
parser.add_argument(
"--output",
required=True,
type=argparse.FileType("w"),
help="path to save json output",
)
args = parser.parse_args()
sp = spm.SentencePieceProcessor()
sp.Load(args.spm_model.name)
tgt_dict = Dictionary.load(args.dictionary)
labels = {}
for line in args.labels:
(utt_id, label) = line.split(" ", 1)
labels[utt_id] = label
if len(labels) == 0:
raise Exception("No labels found in ", args.labels_path)
Sample = namedtuple("Sample", "aud_path utt_id")
samples = []
for path, _, files in chain.from_iterable(
os.walk(path) for path in args.audio_dirs
):
for f in files:
if f.endswith(args.audio_format):
if len(os.path.splitext(f)) != 2:
raise Exception("Expect <utt_id.extension> file name. Got: ", f)
utt_id = os.path.splitext(f)[0]
if utt_id not in labels:
continue
samples.append(Sample(os.path.join(path, f), utt_id))
utts = {}
num_cpu = multiprocessing.cpu_count()
with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
future_to_sample = {
executor.submit(
process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
): s
for s in samples
}
for future in concurrent.futures.as_completed(future_to_sample):
try:
data = future.result()
except Exception as exc:
print("generated an exception: ", exc)
else:
utts.update(data)
json.dump({"utts": utts}, args.output, indent=4)
if __name__ == "__main__":
main()
#!/usr/bin/env bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Prepare librispeech dataset
base_url=www.openslr.org/resources/12
train_dir=train_960
if [ "$#" -ne 2 ]; then
echo "Usage: $0 <download_dir> <out_dir>"
echo "e.g.: $0 /tmp/librispeech_raw/ ~/data/librispeech_final"
exit 1
fi
download_dir=${1%/}
out_dir=${2%/}
fairseq_root=~/fairseq-py/
mkdir -p ${out_dir}
cd ${out_dir} || exit
nbpe=5000
bpemode=unigram
if [ ! -d "$fairseq_root" ]; then
echo "$0: Please set correct fairseq_root"
exit 1
fi
echo "Data Download"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
url=$base_url/$part.tar.gz
if ! wget -P $download_dir $url; then
echo "$0: wget failed for $url"
exit 1
fi
if ! tar -C $download_dir -xvzf $download_dir/$part.tar.gz; then
echo "$0: error un-tarring archive $download_dir/$part.tar.gz"
exit 1
fi
done
echo "Merge all train packs into one"
mkdir -p ${download_dir}/LibriSpeech/${train_dir}/
for part in train-clean-100 train-clean-360 train-other-500; do
mv ${download_dir}/LibriSpeech/${part}/* $download_dir/LibriSpeech/${train_dir}/
done
echo "Merge train text"
find ${download_dir}/LibriSpeech/${train_dir}/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/${train_dir}/text
# Use combined dev-clean and dev-other as validation set
find ${download_dir}/LibriSpeech/dev-clean/ ${download_dir}/LibriSpeech/dev-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/valid_text
find ${download_dir}/LibriSpeech/test-clean/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-clean/text
find ${download_dir}/LibriSpeech/test-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-other/text
dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_units.txt
encoded=data/lang_char/${train_dir}_${bpemode}${nbpe}_encoded.txt
fairseq_dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_fairseq_dict.txt
bpemodel=data/lang_char/${train_dir}_${bpemode}${nbpe}
echo "dictionary: ${dict}"
echo "Dictionary preparation"
mkdir -p data/lang_char/
echo "<unk> 3" > ${dict}
echo "</s> 2" >> ${dict}
echo "<pad> 1" >> ${dict}
cut -f 2- -d" " ${download_dir}/LibriSpeech/${train_dir}/text > data/lang_char/input.txt
spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 --unk_id=3 --eos_id=2 --pad_id=1 --bos_id=-1 --character_coverage=1
spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt > ${encoded}
cat ${encoded} | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+3}' >> ${dict}
cat ${encoded} | tr ' ' '\n' | sort | uniq -c | awk '{print $2 " " $1}' > ${fairseq_dict}
wc -l ${dict}
echo "Prepare train and test jsons"
for part in train_960 test-other test-clean; do
python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/${part} --labels ${download_dir}/LibriSpeech/${part}/text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output ${part}.json
done
# fairseq expects to find train.json and valid.json during training
mv train_960.json train.json
echo "Prepare valid json"
python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/dev-clean ${download_dir}/LibriSpeech/dev-other --labels ${download_dir}/LibriSpeech/valid_text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output valid.json
cp ${fairseq_dict} ./dict.txt
cp ${bpemodel}.model ./spm.model
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Run inference for pre-processed data with a trained model.
"""
import logging
import math
import os
import sys
import editdistance
import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import StopwatchMeter, TimeMeter
logging.basicConfig()
logging.root.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def add_asr_eval_argument(parser):
parser.add_argument("--kspmodel", default=None, help="sentence piece model")
parser.add_argument(
"--wfstlm", default=None, help="wfstlm on dictonary output units"
)
parser.add_argument(
"--rnnt_decoding_type",
default="greedy",
help="wfstlm on dictonary\
output units",
)
try:
parser.add_argument(
"--lm-weight",
"--lm_weight",
type=float,
default=0.2,
help="weight for lm while interpolating with neural score",
)
except:
pass
parser.add_argument(
"--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
)
parser.add_argument(
"--w2l-decoder",
choices=["viterbi", "kenlm", "fairseqlm"],
help="use a w2l decoder",
)
parser.add_argument("--lexicon", help="lexicon for w2l decoder")
parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
parser.add_argument("--beam-threshold", type=float, default=25.0)
parser.add_argument("--beam-size-token", type=float, default=100)
parser.add_argument("--word-score", type=float, default=1.0)
parser.add_argument("--unk-weight", type=float, default=-math.inf)
parser.add_argument("--sil-weight", type=float, default=0.0)
parser.add_argument(
"--dump-emissions",
type=str,
default=None,
help="if present, dumps emissions into this file and exits",
)
parser.add_argument(
"--dump-features",
type=str,
default=None,
help="if present, dumps features into this file and exits",
)
parser.add_argument(
"--load-emissions",
type=str,
default=None,
help="if present, loads emissions from this file",
)
return parser
def check_args(args):
# assert args.path is not None, "--path required for generation!"
# assert args.results_path is not None, "--results_path required for generation!"
assert (
not args.sampling or args.nbest == args.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
args.replace_unk is None or args.raw_text
), "--replace-unk requires a raw text dataset (--raw-text)"
def get_dataset_itr(args, task, models):
return task.get_batch_iterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.batch_size,
max_positions=(sys.maxsize, sys.maxsize),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=args.required_batch_size_multiple,
num_shards=args.num_shards,
shard_id=args.shard_id,
num_workers=args.num_workers,
data_buffer_size=args.data_buffer_size,
).next_epoch_itr(shuffle=False)
def process_predictions(
args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
for hypo in hypos[: min(len(hypos), args.nbest)]:
hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
if "words" in hypo:
hyp_words = " ".join(hypo["words"])
else:
hyp_words = post_process(hyp_pieces, args.remove_bpe)
if res_files is not None:
print(
"{} ({}-{})".format(hyp_pieces, speaker, id),
file=res_files["hypo.units"],
)
print(
"{} ({}-{})".format(hyp_words, speaker, id),
file=res_files["hypo.words"],
)
tgt_pieces = tgt_dict.string(target_tokens)
tgt_words = post_process(tgt_pieces, args.remove_bpe)
if res_files is not None:
print(
"{} ({}-{})".format(tgt_pieces, speaker, id),
file=res_files["ref.units"],
)
print(
"{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
)
# only score top hypothesis
if not args.quiet:
logger.debug("HYPO:" + hyp_words)
logger.debug("TARGET:" + tgt_words)
logger.debug("___________________")
hyp_words = hyp_words.split()
tgt_words = tgt_words.split()
return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
def prepare_result_files(args):
def get_res_file(file_prefix):
if args.num_shards > 1:
file_prefix = f"{args.shard_id}_{file_prefix}"
path = os.path.join(
args.results_path,
"{}-{}-{}.txt".format(
file_prefix, os.path.basename(args.path), args.gen_subset
),
)
return open(path, "w", buffering=1)
if not args.results_path:
return None
return {
"hypo.words": get_res_file("hypo.word"),
"hypo.units": get_res_file("hypo.units"),
"ref.words": get_res_file("ref.word"),
"ref.units": get_res_file("ref.units"),
}
def load_models_and_criterions(
filenames, data_path, arg_overrides=None, task=None, model_state=None
):
models = []
criterions = []
if arg_overrides is None:
arg_overrides = {}
arg_overrides["wer_args"] = None
arg_overrides["data"] = data_path
if filenames is None:
assert model_state is not None
filenames = [0]
else:
filenames = filenames.split(":")
for filename in filenames:
if model_state is None:
if not os.path.exists(filename):
raise IOError("Model file not found: {}".format(filename))
state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)
else:
state = model_state
args = state["args"]
if task is None:
task = tasks.setup_task(args)
model = task.build_model(args)
model.load_state_dict(state["model"], strict=True)
models.append(model)
criterion = task.build_criterion(args)
if "criterion" in state:
criterion.load_state_dict(state["criterion"], strict=True)
criterions.append(criterion)
return models, criterions, args
def optimize_models(args, use_cuda, models):
"""Optimize ensemble for generation"""
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
need_attn=args.print_alignment,
)
if args.fp16:
model.half()
if use_cuda:
model.cuda()
class ExistingEmissionsDecoder(object):
def __init__(self, decoder, emissions):
self.decoder = decoder
self.emissions = emissions
def generate(self, models, sample, **unused):
ids = sample["id"].cpu().numpy()
try:
emissions = np.stack(self.emissions[ids])
except:
print([x.shape for x in self.emissions[ids]])
raise Exception("invalid sizes")
emissions = torch.from_numpy(emissions)
return self.decoder.decode(emissions)
def main(args, task=None, model_state=None):
check_args(args)
if args.max_tokens is None and args.batch_size is None:
args.max_tokens = 4000000
logger.info(args)
use_cuda = torch.cuda.is_available() and not args.cpu
if task is None:
# Load dataset splits
task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)
logger.info(
"| {} {} {} examples".format(
args.data, args.gen_subset, len(task.dataset(args.gen_subset))
)
)
# Set dictionary
tgt_dict = task.target_dictionary
logger.info("| decoding with criterion {}".format(args.criterion))
# Load ensemble
if args.load_emissions:
models, criterions = [], []
else:
logger.info("| loading model(s) from {}".format(args.path))
models, criterions, _ = load_models_and_criterions(
args.path,
data_path=args.data,
arg_overrides=eval(args.model_overrides), # noqa
task=task,
model_state=model_state,
)
optimize_models(args, use_cuda, models)
# hack to pass transitions to W2lDecoder
if args.criterion == "asg_loss":
trans = criterions[0].asg.trans.data
args.asg_transitions = torch.flatten(trans).tolist()
# Load dataset (possibly sharded)
itr = get_dataset_itr(args, task, models)
# Initialize generator
gen_timer = StopwatchMeter()
def build_generator(args):
w2l_decoder = getattr(args, "w2l_decoder", None)
if w2l_decoder == "viterbi":
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
return W2lViterbiDecoder(args, task.target_dictionary)
elif w2l_decoder == "kenlm":
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
return W2lKenLMDecoder(args, task.target_dictionary)
elif w2l_decoder == "fairseqlm":
from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
return W2lFairseqLMDecoder(args, task.target_dictionary)
else:
print(
"only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
)
# please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
generator = build_generator(args)
if args.load_emissions:
generator = ExistingEmissionsDecoder(
generator, np.load(args.load_emissions, allow_pickle=True)
)
logger.info("loaded emissions from " + args.load_emissions)
num_sentences = 0
if args.results_path is not None and not os.path.exists(args.results_path):
os.makedirs(args.results_path)
max_source_pos = (
utils.resolve_max_positions(
task.max_positions(), *[model.max_positions() for model in models]
),
)
if max_source_pos is not None:
max_source_pos = max_source_pos[0]
if max_source_pos is not None:
max_source_pos = max_source_pos[0] - 1
if args.dump_emissions:
emissions = {}
if args.dump_features:
features = {}
models[0].bert.proj = None
else:
res_files = prepare_result_files(args)
errs_t = 0
lengths_t = 0
with progress_bar.build_progress_bar(args, itr) as t:
wps_meter = TimeMeter()
for sample in t:
sample = utils.move_to_cuda(sample) if use_cuda else sample
if "net_input" not in sample:
continue
prefix_tokens = None
if args.prefix_size > 0:
prefix_tokens = sample["target"][:, : args.prefix_size]
gen_timer.start()
if args.dump_emissions:
with torch.no_grad():
encoder_out = models[0](**sample["net_input"])
emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
emm = emm.transpose(0, 1).cpu().numpy()
for i, id in enumerate(sample["id"]):
emissions[id.item()] = emm[i]
continue
elif args.dump_features:
with torch.no_grad():
encoder_out = models[0](**sample["net_input"])
feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
for i, id in enumerate(sample["id"]):
padding = (
encoder_out["encoder_padding_mask"][i].cpu().numpy()
if encoder_out["encoder_padding_mask"] is not None
else None
)
features[id.item()] = (feat[i], padding)
continue
hypos = task.inference_step(generator, models, sample, prefix_tokens)
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
gen_timer.stop(num_generated_tokens)
for i, sample_id in enumerate(sample["id"].tolist()):
speaker = None
# id = task.dataset(args.gen_subset).ids[int(sample_id)]
id = sample_id
toks = (
sample["target"][i, :]
if "target_label" not in sample
else sample["target_label"][i, :]
)
target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
# Process top predictions
errs, length = process_predictions(
args,
hypos[i],
None,
tgt_dict,
target_tokens,
res_files,
speaker,
id,
)
errs_t += errs
lengths_t += length
wps_meter.update(num_generated_tokens)
t.log({"wps": round(wps_meter.avg)})
num_sentences += (
sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
)
wer = None
if args.dump_emissions:
emm_arr = []
for i in range(len(emissions)):
emm_arr.append(emissions[i])
np.save(args.dump_emissions, emm_arr)
logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
elif args.dump_features:
feat_arr = []
for i in range(len(features)):
feat_arr.append(features[i])
np.save(args.dump_features, feat_arr)
logger.info(f"saved {len(features)} emissions to {args.dump_features}")
else:
if lengths_t > 0:
wer = errs_t * 100.0 / lengths_t
logger.info(f"WER: {wer}")
logger.info(
"| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
"sentences/s, {:.2f} tokens/s)".format(
num_sentences,
gen_timer.n,
gen_timer.sum,
num_sentences / gen_timer.sum,
1.0 / gen_timer.avg,
)
)
logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
return task, wer
def make_parser():
parser = options.get_generation_parser()
parser = add_asr_eval_argument(parser)
return parser
def cli_main():
parser = make_parser()
args = options.parse_args_and_arch(parser)
main(args)
if __name__ == "__main__":
cli_main()
import importlib
import os
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module("examples.speech_recognition.models." + model_name)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import math
from collections.abc import Iterable
import torch
import torch.nn as nn
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqEncoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.modules import (
LinearizedConvolution,
TransformerDecoderLayer,
TransformerEncoderLayer,
VGGBlock,
)
@register_model("asr_vggtransformer")
class VGGTransformerModel(FairseqEncoderDecoderModel):
"""
Transformers with convolutional context for ASR
https://arxiv.org/abs/1904.11660
"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--input-feat-per-channel",
type=int,
metavar="N",
help="encoder input dimension per input channel",
)
parser.add_argument(
"--vggblock-enc-config",
type=str,
metavar="EXPR",
help="""
an array of tuples each containing the configuration of one vggblock:
[(out_channels,
conv_kernel_size,
pooling_kernel_size,
num_conv_layers,
use_layer_norm), ...])
""",
)
parser.add_argument(
"--transformer-enc-config",
type=str,
metavar="EXPR",
help=""""
a tuple containing the configuration of the encoder transformer layers
configurations:
[(input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout), ...]')
""",
)
parser.add_argument(
"--enc-output-dim",
type=int,
metavar="N",
help="""
encoder output dimension, can be None. If specified, projecting the
transformer output to the specified dimension""",
)
parser.add_argument(
"--in-channels",
type=int,
metavar="N",
help="number of encoder input channels",
)
parser.add_argument(
"--tgt-embed-dim",
type=int,
metavar="N",
help="embedding dimension of the decoder target tokens",
)
parser.add_argument(
"--transformer-dec-config",
type=str,
metavar="EXPR",
help="""
a tuple containing the configuration of the decoder transformer layers
configurations:
[(input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout), ...]
""",
)
parser.add_argument(
"--conv-dec-config",
type=str,
metavar="EXPR",
help="""
an array of tuples for the decoder 1-D convolution config
[(out_channels, conv_kernel_size, use_layer_norm), ...]""",
)
@classmethod
def build_encoder(cls, args, task):
return VGGTransformerEncoder(
input_feat_per_channel=args.input_feat_per_channel,
vggblock_config=eval(args.vggblock_enc_config),
transformer_config=eval(args.transformer_enc_config),
encoder_output_dim=args.enc_output_dim,
in_channels=args.in_channels,
)
@classmethod
def build_decoder(cls, args, task):
return TransformerDecoder(
dictionary=task.target_dictionary,
embed_dim=args.tgt_embed_dim,
transformer_config=eval(args.transformer_dec_config),
conv_config=eval(args.conv_dec_config),
encoder_output_dim=args.enc_output_dim,
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure that all args are properly defaulted
# (in case there are any new ones)
base_architecture(args)
encoder = cls.build_encoder(args, task)
decoder = cls.build_decoder(args, task)
return cls(encoder, decoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
# net_output['encoder_out'] is a (B, T, D) tensor
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
lprobs.batch_first = True
return lprobs
DEFAULT_ENC_VGGBLOCK_CONFIG = ((32, 3, 2, 2, False),) * 2
DEFAULT_ENC_TRANSFORMER_CONFIG = ((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2
# 256: embedding dimension
# 4: number of heads
# 1024: FFN
# True: apply layerNorm before (dropout + resiaul) instead of after
# 0.2 (dropout): dropout after MultiheadAttention and second FC
# 0.2 (attention_dropout): dropout in MultiheadAttention
# 0.2 (relu_dropout): dropout after ReLu
DEFAULT_DEC_TRANSFORMER_CONFIG = ((256, 2, 1024, True, 0.2, 0.2, 0.2),) * 2
DEFAULT_DEC_CONV_CONFIG = ((256, 3, True),) * 2
# TODO: repace transformer encoder config from one liner
# to explicit args to get rid of this transformation
def prepare_transformer_encoder_params(
input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout,
):
args = argparse.Namespace()
args.encoder_embed_dim = input_dim
args.encoder_attention_heads = num_heads
args.attention_dropout = attention_dropout
args.dropout = dropout
args.activation_dropout = relu_dropout
args.encoder_normalize_before = normalize_before
args.encoder_ffn_embed_dim = ffn_dim
return args
def prepare_transformer_decoder_params(
input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout,
):
args = argparse.Namespace()
args.decoder_embed_dim = input_dim
args.decoder_attention_heads = num_heads
args.attention_dropout = attention_dropout
args.dropout = dropout
args.activation_dropout = relu_dropout
args.decoder_normalize_before = normalize_before
args.decoder_ffn_embed_dim = ffn_dim
return args
class VGGTransformerEncoder(FairseqEncoder):
"""VGG + Transformer encoder"""
def __init__(
self,
input_feat_per_channel,
vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
encoder_output_dim=512,
in_channels=1,
transformer_context=None,
transformer_sampling=None,
):
"""constructor for VGGTransformerEncoder
Args:
- input_feat_per_channel: feature dim (not including stacked,
just base feature)
- in_channel: # input channels (e.g., if stack 8 feature vector
together, this is 8)
- vggblock_config: configuration of vggblock, see comments on
DEFAULT_ENC_VGGBLOCK_CONFIG
- transformer_config: configuration of transformer layer, see comments
on DEFAULT_ENC_TRANSFORMER_CONFIG
- encoder_output_dim: final transformer output embedding dimension
- transformer_context: (left, right) if set, self-attention will be focused
on (t-left, t+right)
- transformer_sampling: an iterable of int, must match with
len(transformer_config), transformer_sampling[i] indicates sampling
factor for i-th transformer layer, after multihead att and feedfoward
part
"""
super().__init__(None)
self.num_vggblocks = 0
if vggblock_config is not None:
if not isinstance(vggblock_config, Iterable):
raise ValueError("vggblock_config is not iterable")
self.num_vggblocks = len(vggblock_config)
self.conv_layers = nn.ModuleList()
self.in_channels = in_channels
self.input_dim = input_feat_per_channel
self.pooling_kernel_sizes = []
if vggblock_config is not None:
for _, config in enumerate(vggblock_config):
(
out_channels,
conv_kernel_size,
pooling_kernel_size,
num_conv_layers,
layer_norm,
) = config
self.conv_layers.append(
VGGBlock(
in_channels,
out_channels,
conv_kernel_size,
pooling_kernel_size,
num_conv_layers,
input_dim=input_feat_per_channel,
layer_norm=layer_norm,
)
)
self.pooling_kernel_sizes.append(pooling_kernel_size)
in_channels = out_channels
input_feat_per_channel = self.conv_layers[-1].output_dim
transformer_input_dim = self.infer_conv_output_dim(
self.in_channels, self.input_dim
)
# transformer_input_dim is the output dimension of VGG part
self.validate_transformer_config(transformer_config)
self.transformer_context = self.parse_transformer_context(transformer_context)
self.transformer_sampling = self.parse_transformer_sampling(
transformer_sampling, len(transformer_config)
)
self.transformer_layers = nn.ModuleList()
if transformer_input_dim != transformer_config[0][0]:
self.transformer_layers.append(
Linear(transformer_input_dim, transformer_config[0][0])
)
self.transformer_layers.append(
TransformerEncoderLayer(
prepare_transformer_encoder_params(*transformer_config[0])
)
)
for i in range(1, len(transformer_config)):
if transformer_config[i - 1][0] != transformer_config[i][0]:
self.transformer_layers.append(
Linear(transformer_config[i - 1][0], transformer_config[i][0])
)
self.transformer_layers.append(
TransformerEncoderLayer(
prepare_transformer_encoder_params(*transformer_config[i])
)
)
self.encoder_output_dim = encoder_output_dim
self.transformer_layers.extend(
[
Linear(transformer_config[-1][0], encoder_output_dim),
LayerNorm(encoder_output_dim),
]
)
def forward(self, src_tokens, src_lengths, **kwargs):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
bsz, max_seq_len, _ = src_tokens.size()
x = src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
x = x.transpose(1, 2).contiguous()
# (B, C, T, feat)
for layer_idx in range(len(self.conv_layers)):
x = self.conv_layers[layer_idx](x)
bsz, _, output_seq_len, _ = x.size()
# (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> (T, B, C * feat)
x = x.transpose(1, 2).transpose(0, 1)
x = x.contiguous().view(output_seq_len, bsz, -1)
input_lengths = src_lengths.clone()
for s in self.pooling_kernel_sizes:
input_lengths = (input_lengths.float() / s).ceil().long()
encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
input_lengths, batch_first=True
)
if not encoder_padding_mask.any():
encoder_padding_mask = None
subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor)
transformer_layer_idx = 0
for layer_idx in range(len(self.transformer_layers)):
if isinstance(self.transformer_layers[layer_idx], TransformerEncoderLayer):
x = self.transformer_layers[layer_idx](
x, encoder_padding_mask, attn_mask
)
if self.transformer_sampling[transformer_layer_idx] != 1:
sampling_factor = self.transformer_sampling[transformer_layer_idx]
x, encoder_padding_mask, attn_mask = self.slice(
x, encoder_padding_mask, attn_mask, sampling_factor
)
transformer_layer_idx += 1
else:
x = self.transformer_layers[layer_idx](x)
# encoder_padding_maks is a (T x B) tensor, its [t, b] elements indicate
# whether encoder_output[t, b] is valid or not (valid=0, invalid=1)
return {
"encoder_out": x, # (T, B, C)
"encoder_padding_mask": encoder_padding_mask.t()
if encoder_padding_mask is not None
else None,
# (B, T) --> (T, B)
}
def infer_conv_output_dim(self, in_channels, input_dim):
sample_seq_len = 200
sample_bsz = 10
x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim)
for i, _ in enumerate(self.conv_layers):
x = self.conv_layers[i](x)
x = x.transpose(1, 2)
mb, seq = x.size()[:2]
return x.contiguous().view(mb, seq, -1).size(-1)
def validate_transformer_config(self, transformer_config):
for config in transformer_config:
input_dim, num_heads = config[:2]
if input_dim % num_heads != 0:
msg = (
"ERROR in transformer config {}: ".format(config)
+ "input dimension {} ".format(input_dim)
+ "not dividable by number of heads {}".format(num_heads)
)
raise ValueError(msg)
def parse_transformer_context(self, transformer_context):
"""
transformer_context can be the following:
- None; indicates no context is used, i.e.,
transformer can access full context
- a tuple/list of two int; indicates left and right context,
any number <0 indicates infinite context
* e.g., (5, 6) indicates that for query at x_t, transformer can
access [t-5, t+6] (inclusive)
* e.g., (-1, 6) indicates that for query at x_t, transformer can
access [0, t+6] (inclusive)
"""
if transformer_context is None:
return None
if not isinstance(transformer_context, Iterable):
raise ValueError("transformer context must be Iterable if it is not None")
if len(transformer_context) != 2:
raise ValueError("transformer context must have length 2")
left_context = transformer_context[0]
if left_context < 0:
left_context = None
right_context = transformer_context[1]
if right_context < 0:
right_context = None
if left_context is None and right_context is None:
return None
return (left_context, right_context)
def parse_transformer_sampling(self, transformer_sampling, num_layers):
"""
parsing transformer sampling configuration
Args:
- transformer_sampling, accepted input:
* None, indicating no sampling
* an Iterable with int (>0) as element
- num_layers, expected number of transformer layers, must match with
the length of transformer_sampling if it is not None
Returns:
- A tuple with length num_layers
"""
if transformer_sampling is None:
return (1,) * num_layers
if not isinstance(transformer_sampling, Iterable):
raise ValueError(
"transformer_sampling must be an iterable if it is not None"
)
if len(transformer_sampling) != num_layers:
raise ValueError(
"transformer_sampling {} does not match with the number "
"of layers {}".format(transformer_sampling, num_layers)
)
for layer, value in enumerate(transformer_sampling):
if not isinstance(value, int):
raise ValueError("Invalid value in transformer_sampling: ")
if value < 1:
raise ValueError(
"{} layer's subsampling is {}.".format(layer, value)
+ " This is not allowed! "
)
return transformer_sampling
def slice(self, embedding, padding_mask, attn_mask, sampling_factor):
"""
embedding is a (T, B, D) tensor
padding_mask is a (B, T) tensor or None
attn_mask is a (T, T) tensor or None
"""
embedding = embedding[::sampling_factor, :, :]
if padding_mask is not None:
padding_mask = padding_mask[:, ::sampling_factor]
if attn_mask is not None:
attn_mask = attn_mask[::sampling_factor, ::sampling_factor]
return embedding, padding_mask, attn_mask
def lengths_to_attn_mask(self, input_lengths, subsampling_factor=1):
"""
create attention mask according to sequence lengths and transformer
context
Args:
- input_lengths: (B, )-shape Int/Long tensor; input_lengths[b] is
the length of b-th sequence
- subsampling_factor: int
* Note that the left_context and right_context is specified in
the input frame-level while input to transformer may already
go through subsampling (e.g., the use of striding in vggblock)
we use subsampling_factor to scale the left/right context
Return:
- a (T, T) binary tensor or None, where T is max(input_lengths)
* if self.transformer_context is None, None
* if left_context is None,
* attn_mask[t, t + right_context + 1:] = 1
* others = 0
* if right_context is None,
* attn_mask[t, 0:t - left_context] = 1
* others = 0
* elsif
* attn_mask[t, t - left_context: t + right_context + 1] = 0
* others = 1
"""
if self.transformer_context is None:
return None
maxT = torch.max(input_lengths).item()
attn_mask = torch.zeros(maxT, maxT)
left_context = self.transformer_context[0]
right_context = self.transformer_context[1]
if left_context is not None:
left_context = math.ceil(self.transformer_context[0] / subsampling_factor)
if right_context is not None:
right_context = math.ceil(self.transformer_context[1] / subsampling_factor)
for t in range(maxT):
if left_context is not None:
st = 0
en = max(st, t - left_context)
attn_mask[t, st:en] = 1
if right_context is not None:
st = t + right_context + 1
st = min(st, maxT - 1)
attn_mask[t, st:] = 1
return attn_mask.to(input_lengths.device)
def reorder_encoder_out(self, encoder_out, new_order):
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
1, new_order
)
if encoder_out["encoder_padding_mask"] is not None:
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(1, new_order)
return encoder_out
class TransformerDecoder(FairseqIncrementalDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs.
Default: ``False``
left_pad (bool, optional): whether the input is left-padded. Default:
``False``
"""
def __init__(
self,
dictionary,
embed_dim=512,
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
conv_config=DEFAULT_DEC_CONV_CONFIG,
encoder_output_dim=512,
):
super().__init__(dictionary)
vocab_size = len(dictionary)
self.padding_idx = dictionary.pad()
self.embed_tokens = Embedding(vocab_size, embed_dim, self.padding_idx)
self.conv_layers = nn.ModuleList()
for i in range(len(conv_config)):
out_channels, kernel_size, layer_norm = conv_config[i]
if i == 0:
conv_layer = LinearizedConv1d(
embed_dim, out_channels, kernel_size, padding=kernel_size - 1
)
else:
conv_layer = LinearizedConv1d(
conv_config[i - 1][0],
out_channels,
kernel_size,
padding=kernel_size - 1,
)
self.conv_layers.append(conv_layer)
if layer_norm:
self.conv_layers.append(nn.LayerNorm(out_channels))
self.conv_layers.append(nn.ReLU())
self.layers = nn.ModuleList()
if conv_config[-1][0] != transformer_config[0][0]:
self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
self.layers.append(
TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[0])
)
)
for i in range(1, len(transformer_config)):
if transformer_config[i - 1][0] != transformer_config[i][0]:
self.layers.append(
Linear(transformer_config[i - 1][0], transformer_config[i][0])
)
self.layers.append(
TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[i])
)
)
self.fc_out = Linear(transformer_config[-1][0], vocab_size)
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for input feeding/teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
Returns:
tuple:
- the last decoder layer's output of shape `(batch, tgt_len,
vocab)`
- the last decoder layer's attention weights of shape `(batch,
tgt_len, src_len)`
"""
target_padding_mask = (
(prev_output_tokens == self.padding_idx).to(prev_output_tokens.device)
if incremental_state is None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
# embed tokens
x = self.embed_tokens(prev_output_tokens)
# B x T x C -> T x B x C
x = self._transpose_if_training(x, incremental_state)
for layer in self.conv_layers:
if isinstance(layer, LinearizedConvolution):
x = layer(x, incremental_state)
else:
x = layer(x)
# B x T x C -> T x B x C
x = self._transpose_if_inference(x, incremental_state)
# decoder layers
for layer in self.layers:
if isinstance(layer, TransformerDecoderLayer):
x, *_ = layer(
x,
(encoder_out["encoder_out"] if encoder_out is not None else None),
(
encoder_out["encoder_padding_mask"].t()
if encoder_out["encoder_padding_mask"] is not None
else None
),
incremental_state,
self_attn_mask=(
self.buffered_future_mask(x)
if incremental_state is None
else None
),
self_attn_padding_mask=(
target_padding_mask if incremental_state is None else None
),
)
else:
x = layer(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
x = self.fc_out(x)
return x, None
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
not hasattr(self, "_future_mask")
or self._future_mask is None
or self._future_mask.device != tensor.device
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
)
if self._future_mask.size(0) < dim:
self._future_mask = torch.triu(
utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
)
return self._future_mask[:dim, :dim]
def _transpose_if_training(self, x, incremental_state):
if incremental_state is None:
x = x.transpose(0, 1)
return x
def _transpose_if_inference(self, x, incremental_state):
if incremental_state:
x = x.transpose(0, 1)
return x
@register_model("asr_vggtransformer_encoder")
class VGGTransformerEncoderModel(FairseqEncoderModel):
def __init__(self, encoder):
super().__init__(encoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--input-feat-per-channel",
type=int,
metavar="N",
help="encoder input dimension per input channel",
)
parser.add_argument(
"--vggblock-enc-config",
type=str,
metavar="EXPR",
help="""
an array of tuples each containing the configuration of one vggblock
[(out_channels, conv_kernel_size, pooling_kernel_size,num_conv_layers), ...]
""",
)
parser.add_argument(
"--transformer-enc-config",
type=str,
metavar="EXPR",
help="""
a tuple containing the configuration of the Transformer layers
configurations:
[(input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout), ]""",
)
parser.add_argument(
"--enc-output-dim",
type=int,
metavar="N",
help="encoder output dimension, projecting the LSTM output",
)
parser.add_argument(
"--in-channels",
type=int,
metavar="N",
help="number of encoder input channels",
)
parser.add_argument(
"--transformer-context",
type=str,
metavar="EXPR",
help="""
either None or a tuple of two ints, indicating left/right context a
transformer can have access to""",
)
parser.add_argument(
"--transformer-sampling",
type=str,
metavar="EXPR",
help="""
either None or a tuple of ints, indicating sampling factor in each layer""",
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
base_architecture_enconly(args)
encoder = VGGTransformerEncoderOnly(
vocab_size=len(task.target_dictionary),
input_feat_per_channel=args.input_feat_per_channel,
vggblock_config=eval(args.vggblock_enc_config),
transformer_config=eval(args.transformer_enc_config),
encoder_output_dim=args.enc_output_dim,
in_channels=args.in_channels,
transformer_context=eval(args.transformer_context),
transformer_sampling=eval(args.transformer_sampling),
)
return cls(encoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
# net_output['encoder_out'] is a (T, B, D) tensor
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
# lprobs is a (T, B, D) tensor
# we need to transoose to get (B, T, D) tensor
lprobs = lprobs.transpose(0, 1).contiguous()
lprobs.batch_first = True
return lprobs
class VGGTransformerEncoderOnly(VGGTransformerEncoder):
def __init__(
self,
vocab_size,
input_feat_per_channel,
vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
encoder_output_dim=512,
in_channels=1,
transformer_context=None,
transformer_sampling=None,
):
super().__init__(
input_feat_per_channel=input_feat_per_channel,
vggblock_config=vggblock_config,
transformer_config=transformer_config,
encoder_output_dim=encoder_output_dim,
in_channels=in_channels,
transformer_context=transformer_context,
transformer_sampling=transformer_sampling,
)
self.fc_out = Linear(self.encoder_output_dim, vocab_size)
def forward(self, src_tokens, src_lengths, **kwargs):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
enc_out = super().forward(src_tokens, src_lengths)
x = self.fc_out(enc_out["encoder_out"])
# x = F.log_softmax(x, dim=-1)
# Note: no need this line, because model.get_normalized_prob will call
# log_softmax
return {
"encoder_out": x, # (T, B, C)
"encoder_padding_mask": enc_out["encoder_padding_mask"], # (T, B)
}
def max_positions(self):
"""Maximum input length supported by the encoder."""
return (1e6, 1e6) # an arbitrary large number
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
# nn.init.uniform_(m.weight, -0.1, 0.1)
# nn.init.constant_(m.weight[padding_idx], 0)
return m
def Linear(in_features, out_features, bias=True, dropout=0):
"""Linear layer (input: N x T x C)"""
m = nn.Linear(in_features, out_features, bias=bias)
# m.weight.data.uniform_(-0.1, 0.1)
# if bias:
# m.bias.data.uniform_(-0.1, 0.1)
return m
def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
"""Weight-normalized Conv1d layer optimized for decoding"""
m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
nn.init.normal_(m.weight, mean=0, std=std)
nn.init.constant_(m.bias, 0)
return nn.utils.weight_norm(m, dim=2)
def LayerNorm(embedding_dim):
m = nn.LayerNorm(embedding_dim)
return m
# seq2seq models
def base_architecture(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", DEFAULT_ENC_VGGBLOCK_CONFIG
)
args.transformer_enc_config = getattr(
args, "transformer_enc_config", DEFAULT_ENC_TRANSFORMER_CONFIG
)
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
args.in_channels = getattr(args, "in_channels", 1)
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
args.transformer_dec_config = getattr(
args, "transformer_dec_config", DEFAULT_ENC_TRANSFORMER_CONFIG
)
args.conv_dec_config = getattr(args, "conv_dec_config", DEFAULT_DEC_CONV_CONFIG)
args.transformer_context = getattr(args, "transformer_context", "None")
@register_model_architecture("asr_vggtransformer", "vggtransformer_1")
def vggtransformer_1(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args.transformer_enc_config = getattr(
args,
"transformer_enc_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14",
)
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
args.transformer_dec_config = getattr(
args,
"transformer_dec_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4",
)
@register_model_architecture("asr_vggtransformer", "vggtransformer_2")
def vggtransformer_2(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args.transformer_enc_config = getattr(
args,
"transformer_enc_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
)
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
args.transformer_dec_config = getattr(
args,
"transformer_dec_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6",
)
@register_model_architecture("asr_vggtransformer", "vggtransformer_base")
def vggtransformer_base(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args.transformer_enc_config = getattr(
args, "transformer_enc_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12"
)
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
args.transformer_dec_config = getattr(
args, "transformer_dec_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6"
)
# Size estimations:
# Encoder:
# - vggblock param: 64*1*3*3 + 64*64*3*3 + 128*64*3*3 + 128*128*3 = 258K
# Transformer:
# - input dimension adapter: 2560 x 512 -> 1.31M
# - transformer_layers (x12) --> 37.74M
# * MultiheadAttention: 512*512*3 (in_proj) + 512*512 (out_proj) = 1.048M
# * FFN weight: 512*2048*2 = 2.097M
# - output dimension adapter: 512 x 512 -> 0.26 M
# Decoder:
# - LinearizedConv1d: 512 * 256 * 3 + 256 * 256 * 3 * 3
# - transformer_layer: (x6) --> 25.16M
# * MultiheadAttention (self-attention): 512*512*3 + 512*512 = 1.048M
# * MultiheadAttention (encoder-attention): 512*512*3 + 512*512 = 1.048M
# * FFN: 512*2048*2 = 2.097M
# Final FC:
# - FC: 512*5000 = 256K (assuming vocab size 5K)
# In total:
# ~65 M
# CTC models
def base_architecture_enconly(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(32, 3, 2, 2, True)] * 2"
)
args.transformer_enc_config = getattr(
args, "transformer_enc_config", "((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2"
)
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
args.in_channels = getattr(args, "in_channels", 1)
args.transformer_context = getattr(args, "transformer_context", "None")
args.transformer_sampling = getattr(args, "transformer_sampling", "None")
@register_model_architecture("asr_vggtransformer_encoder", "vggtransformer_enc_1")
def vggtransformer_enc_1(args):
# vggtransformer_1 is the same as vggtransformer_enc_big, except the number
# of layers is increased to 16
# keep it here for backward compatiablity purpose
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args.transformer_enc_config = getattr(
args,
"transformer_enc_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
)
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment