"megatron/legacy/model/module.py" did not exist on "dd889062646a74427e2cc525191ae072cea70734"
Commit a7785cc6 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

delete soft link

parent 9a2a05ca
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
import torch
from torch.nn.utils.rnn import pad_sequence
try:
import k2
from icefall.utils import get_texts
from icefall.decode import get_lattice, Nbest, one_best_decoding
except ImportError:
print('Failed to import k2 and icefall. \
Notice that they are necessary for hlg_onebest and hlg_rescore')
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import TransformerEncoder
from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss
from wenet.utils.common import (IGNORE_ID, add_sos_eos, log_add,
remove_duplicates_and_blank, th_accuracy,
reverse_pad_list)
from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
mask_finished_scores, subsequent_mask)
class ASRModel(torch.nn.Module):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
self,
vocab_size: int,
encoder: TransformerEncoder,
decoder: TransformerDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
ignore_id: int = IGNORE_ID,
reverse_weight: float = 0.0,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
super().__init__()
# note that eos is the same as sos (equivalent ID)
self.sos = vocab_size - 1
self.eos = vocab_size - 1
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.reverse_weight = reverse_weight
self.encoder = encoder
self.decoder = decoder
self.ctc = ctc
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Dict[str, Optional[torch.Tensor]]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
text.shape, text_lengths.shape)
# 1. Encoder
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
# 2a. Attention-decoder branch
if self.ctc_weight != 1.0:
loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
text, text_lengths)
else:
loss_att = None
# 2b. CTC branch
if self.ctc_weight != 0.0:
loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
text_lengths)
else:
loss_ctc = None
if loss_ctc is None:
loss = loss_att
elif loss_att is None:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 -
self.ctc_weight) * loss_att
return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc}
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_mask: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
) -> Tuple[torch.Tensor, float]:
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# reverse the seq, used for right to left decoder
r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos,
self.ignore_id)
# 1. Forward decoder
decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask,
ys_in_pad, ys_in_lens,
r_ys_in_pad,
self.reverse_weight)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
r_loss_att = torch.tensor(0.0)
if self.reverse_weight > 0.0:
r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
loss_att = loss_att * (
1 - self.reverse_weight) + r_loss_att * self.reverse_weight
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
return loss_att, acc_att
def _forward_encoder(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
decoding_chunk_size: int = -1,
num_decoding_left_chunks: int = -1,
simulate_streaming: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Let's assume B = batch_size
# 1. Encoder
if simulate_streaming and decoding_chunk_size > 0:
encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk(
speech,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks
) # (B, maxlen, encoder_dim)
else:
encoder_out, encoder_mask = self.encoder(
speech,
speech_lengths,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks
) # (B, maxlen, encoder_dim)
return encoder_out, encoder_mask
def recognize(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
beam_size: int = 10,
decoding_chunk_size: int = -1,
num_decoding_left_chunks: int = -1,
simulate_streaming: bool = False,
) -> torch.Tensor:
""" Apply beam search on attention decoder
Args:
speech (torch.Tensor): (batch, max_len, feat_dim)
speech_length (torch.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
Returns:
torch.Tensor: decoding result, (batch, max_result_len)
"""
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
device = speech.device
batch_size = speech.shape[0]
# Let's assume B = batch_size and N = beam_size
# 1. Encoder
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
encoder_dim = encoder_out.size(2)
running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
encoder_mask = encoder_mask.unsqueeze(1).repeat(
1, beam_size, 1, 1).view(running_size, 1,
maxlen) # (B*N, 1, max_len)
hyps = torch.ones([running_size, 1], dtype=torch.long,
device=device).fill_(self.sos) # (B*N, 1)
scores = torch.tensor([0.0] + [-float('inf')] * (beam_size - 1),
dtype=torch.float)
scores = scores.to(device).repeat([batch_size]).unsqueeze(1).to(
device) # (B*N, 1)
end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
cache: Optional[List[torch.Tensor]] = None
# 2. Decoder forward step by step
for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos
if end_flag.sum() == running_size:
break
# 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i)
# logp: (B*N, vocab)
logp, cache = self.decoder.forward_one_step(
encoder_out, encoder_mask, hyps, hyps_mask, cache)
# 2.2 First beam prune: select topk best prob at current time
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
# 2.3 Second beam prune: select topk score with history
scores = scores + top_k_logp # (B*N, N), broadcast add
scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N)
scores, offset_k_index = scores.topk(k=beam_size) # (B, N)
# Update cache to be consistent with new topk scores / hyps
cache_index = (offset_k_index // beam_size).view(-1) # (B*N)
base_cache_index = (torch.arange(batch_size, device=device).view(
-1, 1).repeat([1, beam_size]) * beam_size).view(-1) # (B*N)
cache_index = base_cache_index + cache_index
cache = [torch.index_select(c, dim=0, index=cache_index) for c in cache]
scores = scores.view(-1, 1) # (B*N, 1)
# 2.4. Compute base index in top_k_index,
# regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
# then find offset_k_index in top_k_index
base_k_index = torch.arange(batch_size, device=device).view(
-1, 1).repeat([1, beam_size]) # (B, N)
base_k_index = base_k_index * beam_size * beam_size
best_k_index = base_k_index.view(-1) + offset_k_index.view(
-1) # (B*N)
# 2.5 Update best hyps
best_k_pred = torch.index_select(top_k_index.view(-1),
dim=-1,
index=best_k_index) # (B*N)
best_hyps_index = best_k_index // beam_size
last_best_k_hyps = torch.index_select(
hyps, dim=0, index=best_hyps_index) # (B*N, i)
hyps = torch.cat((last_best_k_hyps, best_k_pred.view(-1, 1)),
dim=1) # (B*N, i+1)
# 2.6 Update end flag
end_flag = torch.eq(hyps[:, -1], self.eos).view(-1, 1)
# 3. Select best of best
scores = scores.view(batch_size, beam_size)
# TODO: length normalization
best_scores, best_index = scores.max(dim=-1)
best_hyps_index = best_index + torch.arange(
batch_size, dtype=torch.long, device=device) * beam_size
best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index)
best_hyps = best_hyps[:, 1:]
return best_hyps, best_scores
def ctc_greedy_search(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
decoding_chunk_size: int = -1,
num_decoding_left_chunks: int = -1,
simulate_streaming: bool = False,
) -> List[List[int]]:
""" Apply CTC greedy search
Args:
speech (torch.Tensor): (batch, max_len, feat_dim)
speech_length (torch.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
Returns:
List[List[int]]: best path result
"""
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
batch_size = speech.shape[0]
# Let's assume B = batch_size
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
ctc_probs = self.ctc.log_softmax(
encoder_out) # (B, maxlen, vocab_size)
topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1)
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
mask = make_pad_mask(encoder_out_lens, maxlen) # (B, maxlen)
topk_index = topk_index.masked_fill_(mask, self.eos) # (B, maxlen)
hyps = [hyp.tolist() for hyp in topk_index]
scores = topk_prob.max(1)
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
return hyps, scores
def _ctc_prefix_beam_search(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
beam_size: int,
decoding_chunk_size: int = -1,
num_decoding_left_chunks: int = -1,
simulate_streaming: bool = False,
) -> Tuple[List[List[int]], torch.Tensor]:
""" CTC prefix beam search inner implementation
Args:
speech (torch.Tensor): (batch, max_len, feat_dim)
speech_length (torch.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
Returns:
List[List[int]]: nbest results
torch.Tensor: encoder output, (1, max_len, encoder_dim),
it will be used for rescoring in attention rescoring mode
"""
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
batch_size = speech.shape[0]
# For CTC prefix beam search, we only support batch_size=1
assert batch_size == 1
# Let's assume B = batch_size and N = beam_size
# 1. Encoder forward and get CTC score
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
ctc_probs = self.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
# 2.1 First beam prune: select topk best
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == 0: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
# 2.2 Second beam prune
next_hyps = sorted(next_hyps.items(),
key=lambda x: log_add(list(x[1])),
reverse=True)
cur_hyps = next_hyps[:beam_size]
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
return hyps, encoder_out
def ctc_prefix_beam_search(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
beam_size: int,
decoding_chunk_size: int = -1,
num_decoding_left_chunks: int = -1,
simulate_streaming: bool = False,
) -> List[int]:
""" Apply CTC prefix beam search
Args:
speech (torch.Tensor): (batch, max_len, feat_dim)
speech_length (torch.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
Returns:
List[int]: CTC prefix beam search nbest results
"""
hyps, _ = self._ctc_prefix_beam_search(speech, speech_lengths,
beam_size, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming)
return hyps[0]
def attention_rescoring(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
beam_size: int,
decoding_chunk_size: int = -1,
num_decoding_left_chunks: int = -1,
ctc_weight: float = 0.0,
simulate_streaming: bool = False,
reverse_weight: float = 0.0,
) -> List[int]:
""" Apply attention rescoring decoding, CTC prefix beam search
is applied first to get nbest, then we resoring the nbest on
attention decoder with corresponding encoder out
Args:
speech (torch.Tensor): (batch, max_len, feat_dim)
speech_length (torch.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
reverse_weight (float): right to left decoder weight
ctc_weight (float): ctc score weight
Returns:
List[int]: Attention rescoring result
"""
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
if reverse_weight > 0.0:
# decoder should be a bitransformer decoder if reverse_weight > 0.0
assert hasattr(self.decoder, 'right_decoder')
device = speech.device
batch_size = speech.shape[0]
# For attention rescoring we only support batch_size=1
assert batch_size == 1
# encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
hyps, encoder_out = self._ctc_prefix_beam_search(
speech, speech_lengths, beam_size, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
assert len(hyps) == beam_size
hyps_pad = pad_sequence([
torch.tensor(hyp[0], device=device, dtype=torch.long)
for hyp in hyps
], True, self.ignore_id) # (beam_size, max_hyps_len)
ori_hyps_pad = hyps_pad
hyps_lens = torch.tensor([len(hyp[0]) for hyp in hyps],
device=device,
dtype=torch.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = torch.ones(beam_size,
1,
encoder_out.size(1),
dtype=torch.bool,
device=device)
# used for right to left decoder
r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
self.ignore_id)
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
reverse_weight) # (beam_size, max_hyps_len, vocab_size)
decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
decoder_out = decoder_out.cpu().numpy()
# r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
# conventional transformer decoder.
r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
r_decoder_out = r_decoder_out.cpu().numpy()
# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
for i, hyp in enumerate(hyps):
score = 0.0
for j, w in enumerate(hyp[0]):
score += decoder_out[i][j][w]
score += decoder_out[i][len(hyp[0])][self.eos]
# add right to left decoder score
if reverse_weight > 0:
r_score = 0.0
for j, w in enumerate(hyp[0]):
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
r_score += r_decoder_out[i][len(hyp[0])][self.eos]
score = score * (1 - reverse_weight) + r_score * reverse_weight
# add ctc score
score += hyp[1] * ctc_weight
if score > best_score:
best_score = score
best_index = i
return hyps[best_index][0], best_score
def load_hlg_resource_if_necessary(self, hlg, word):
if not hasattr(self, 'hlg'):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.hlg = k2.Fsa.from_dict(torch.load(hlg, map_location=device))
if not hasattr(self.hlg, "lm_scores"):
self.hlg.lm_scores = self.hlg.scores.clone()
if not hasattr(self, 'word_table'):
self.word_table = {}
with open(word, 'r') as fin:
for line in fin:
arr = line.strip().split()
assert len(arr) == 2
self.word_table[int(arr[1])] = arr[0]
@torch.no_grad()
def hlg_onebest(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
decoding_chunk_size: int = -1,
num_decoding_left_chunks: int = -1,
simulate_streaming: bool = False,
hlg: str = '',
word: str = '',
symbol_table: Dict[str, int] = None,
) -> List[int]:
self.load_hlg_resource_if_necessary(hlg, word)
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
ctc_probs = self.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
supervision_segments = torch.stack(
(torch.arange(len(encoder_mask)),
torch.zeros(len(encoder_mask)),
encoder_mask.squeeze(dim=1).sum(dim=1).cpu()), 1,).to(torch.int32)
lattice = get_lattice(
nnet_output=ctc_probs,
decoding_graph=self.hlg,
supervision_segments=supervision_segments,
search_beam=20,
output_beam=7,
min_active_states=30,
max_active_states=10000,
subsampling_factor=4)
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
hyps = get_texts(best_path)
hyps = [[symbol_table[k] for j in i for k in self.word_table[j]] for i in hyps]
return hyps
@torch.no_grad()
def hlg_rescore(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
decoding_chunk_size: int = -1,
num_decoding_left_chunks: int = -1,
simulate_streaming: bool = False,
lm_scale: float = 0,
decoder_scale: float = 0,
r_decoder_scale: float = 0,
hlg: str = '',
word: str = '',
symbol_table: Dict[str, int] = None,
) -> List[int]:
self.load_hlg_resource_if_necessary(hlg, word)
device = speech.device
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
ctc_probs = self.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
supervision_segments = torch.stack(
(torch.arange(len(encoder_mask)),
torch.zeros(len(encoder_mask)),
encoder_mask.squeeze(dim=1).sum(dim=1).cpu()), 1,).to(torch.int32)
lattice = get_lattice(
nnet_output=ctc_probs,
decoding_graph=self.hlg,
supervision_segments=supervision_segments,
search_beam=20,
output_beam=7,
min_active_states=30,
max_active_states=10000,
subsampling_factor=4)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=100,
use_double_scores=True,
nbest_scale=0.5,)
nbest = nbest.intersect(lattice)
assert hasattr(nbest.fsa, "lm_scores")
assert hasattr(nbest.fsa, "tokens")
assert isinstance(nbest.fsa.tokens, torch.Tensor)
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
tokens = tokens.remove_values_leq(0)
hyps = tokens.tolist()
# cal attention_score
hyps_pad = pad_sequence([
torch.tensor(hyp, device=device, dtype=torch.long)
for hyp in hyps
], True, self.ignore_id) # (beam_size, max_hyps_len)
ori_hyps_pad = hyps_pad
hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
device=device,
dtype=torch.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out_repeat = []
tot_scores = nbest.tot_scores()
repeats = [tot_scores[i].shape[0] for i in range(tot_scores.dim0)]
for i in range(len(encoder_out)):
encoder_out_repeat.append(encoder_out[i: i + 1].repeat(repeats[i], 1, 1))
encoder_out = torch.concat(encoder_out_repeat, dim=0)
encoder_mask = torch.ones(encoder_out.size(0),
1,
encoder_out.size(1),
dtype=torch.bool,
device=device)
# used for right to left decoder
r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
self.ignore_id)
reverse_weight = 0.5
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
reverse_weight) # (beam_size, max_hyps_len, vocab_size)
decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
decoder_out = decoder_out
# r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
# conventional transformer decoder.
r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
r_decoder_out = r_decoder_out
decoder_scores = torch.tensor([sum([decoder_out[i, j, hyps[i][j]]
for j in range(len(hyps[i]))])
for i in range(len(hyps))], device=device)
r_decoder_scores = []
for i in range(len(hyps)):
score = 0
for j in range(len(hyps[i])):
score += r_decoder_out[i, len(hyps[i]) - j - 1, hyps[i][j]]
score += r_decoder_out[i, len(hyps[i]), self.eos]
r_decoder_scores.append(score)
r_decoder_scores = torch.tensor(r_decoder_scores, device=device)
am_scores = nbest.compute_am_scores()
ngram_lm_scores = nbest.compute_lm_scores()
tot_scores = am_scores.values + lm_scale * ngram_lm_scores.values + \
decoder_scale * decoder_scores + r_decoder_scale * r_decoder_scores
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
hyps = [[symbol_table[k] for j in i for k in self.word_table[j]] for i in hyps]
return hyps
@torch.jit.export
def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the
model
"""
return self.encoder.embed.subsampling_rate
@torch.jit.export
def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model
"""
return self.encoder.embed.right_context
@torch.jit.export
def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model
"""
return self.sos
@torch.jit.export
def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model
"""
return self.eos
@torch.jit.export
def forward_encoder_chunk(
self,
xs: torch.Tensor,
offset: int,
required_cache_size: int,
att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.
Args:
xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
where `time == (chunk_size - 1) * subsample_rate + \
subsample.right_context + 1`
offset (int): current offset in encoder output time stamp
required_cache_size (int): cache size required for next chunk
compuation
>=0: actual cache size
<0: means all history cache is required
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(elayers, head, cache_t1, d_k * 2), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * num_decoding_left_chunks`.
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
(elayers, b=1, hidden-dim, cache_t2), where
`cache_t2 == cnn.lorder - 1`
Returns:
torch.Tensor: output of current input xs,
with shape (b=1, chunk_size, hidden-dim).
torch.Tensor: new attention cache required for next chunk, with
dynamic shape (elayers, head, ?, d_k * 2)
depending on required_cache_size.
torch.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
return self.encoder.forward_chunk(xs, offset, required_cache_size,
att_cache, cnn_cache)
@torch.jit.export
def ctc_activation(self, xs: torch.Tensor) -> torch.Tensor:
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
Args:
xs (torch.Tensor): encoder output
Returns:
torch.Tensor: activation before ctc
"""
return self.ctc.log_softmax(xs)
@torch.jit.export
def is_bidirectional_decoder(self) -> bool:
"""
Returns:
torch.Tensor: decoder output
"""
if hasattr(self.decoder, 'right_decoder'):
return True
else:
return False
@torch.jit.export
def forward_attention_decoder(
self,
hyps: torch.Tensor,
hyps_lens: torch.Tensor,
encoder_out: torch.Tensor,
reverse_weight: float = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
""" Export interface for c++ call, forward decoder with multiple
hypothesis from ctc prefix beam search and one encoder output
Args:
hyps (torch.Tensor): hyps from ctc prefix beam search, already
pad sos at the begining
hyps_lens (torch.Tensor): length of each hyp in hyps
encoder_out (torch.Tensor): corresponding encoder output
r_hyps (torch.Tensor): hyps from ctc prefix beam search, already
pad eos at the begining which is used fo right to left decoder
reverse_weight: used for verfing whether used right to left decoder,
> 0 will use.
Returns:
torch.Tensor: decoder output
"""
assert encoder_out.size(0) == 1
num_hyps = hyps.size(0)
assert hyps_lens.size(0) == num_hyps
encoder_out = encoder_out.repeat(num_hyps, 1, 1)
encoder_mask = torch.ones(num_hyps,
1,
encoder_out.size(1),
dtype=torch.bool,
device=encoder_out.device)
# input for right to left decoder
# this hyps_lens has count <sos> token, we need minus it.
r_hyps_lens = hyps_lens - 1
# this hyps has included <sos> token, so it should be
# convert the original hyps.
r_hyps = hyps[:, 1:]
# >>> r_hyps
# >>> tensor([[ 1, 2, 3],
# >>> [ 9, 8, 4],
# >>> [ 2, -1, -1]])
# >>> r_hyps_lens
# >>> tensor([3, 3, 1])
# NOTE(Mddct): `pad_sequence` is not supported by ONNX, it is used
# in `reverse_pad_list` thus we have to refine the below code.
# Issue: https://github.com/wenet-e2e/wenet/issues/1113
# Equal to:
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
max_len = torch.max(r_hyps_lens)
index_range = torch.arange(0, max_len, 1).to(encoder_out.device)
seq_len_expand = r_hyps_lens.unsqueeze(1)
seq_mask = seq_len_expand > index_range # (beam, max_len)
# >>> seq_mask
# >>> tensor([[ True, True, True],
# >>> [ True, True, True],
# >>> [ True, False, False]])
index = (seq_len_expand - 1) - index_range # (beam, max_len)
# >>> index
# >>> tensor([[ 2, 1, 0],
# >>> [ 2, 1, 0],
# >>> [ 0, -1, -2]])
index = index * seq_mask
# >>> index
# >>> tensor([[2, 1, 0],
# >>> [2, 1, 0],
# >>> [0, 0, 0]])
r_hyps = torch.gather(r_hyps, 1, index)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, 2, 2]])
r_hyps = torch.where(seq_mask, r_hyps, self.eos)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, eos, eos]])
r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1)
# >>> r_hyps
# >>> tensor([[sos, 3, 2, 1],
# >>> [sos, 4, 8, 9],
# >>> [sos, 2, eos, eos]])
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps, hyps_lens, r_hyps,
reverse_weight) # (num_hyps, max_hyps_len, vocab_size)
decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
# right to left decoder may be not used during decoding process,
# which depends on reverse_weight param.
# r_dccoder_out will be 0.0, if reverse_weight is 0.0
r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
return decoder_out, r_decoder_out
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""
import math
from typing import Tuple
import torch
from torch import nn
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
"""Construct an MultiHeadedAttention object."""
super().__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)
def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
Returns:
torch.Tensor: Transformed query tensor, size
(#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor, size
(#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor, size
(#batch, n_head, time2, d_k).
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
return q, k, v
def forward_attention(
self, value: torch.Tensor, scores: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
) -> torch.Tensor:
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value, size
(#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score, size
(#batch, n_head, time1, time2).
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
# 1st chunk to ease the onnx export.]
# 2. pytorch training
if mask.size(2) > 0 : # time2 > 0
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
# For last chunk, time2 might be larger than scores.size(-1)
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0) # (batch, head, time1, time2)
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
# 1. onnx(16/-1, -1/-1, 16/0)
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
1.When applying cross attention between decoder and encoder,
the batch padding mask for input is in (#batch, 1, T) shape.
2.When applying self attention of encoder,
the mask is in (#batch, T, T) shape.
3.When applying self attention of decoder,
the mask is in (#batch, L, L) shape.
4.If the different position in decoder see different block
of the encoder, such as Mocha, the passed in mask could be
in (#batch, L, T) shape. But there is no such case in current
Wenet.
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
# NOTE(xcsong):
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
# and we will always do splitting and
# concatnation(this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors(see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.size(0) > 0:
key_cache, value_cache = torch.split(
cache, cache.size(-1) // 2, dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self, n_head, n_feat, dropout_rate):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x, zero_triu: bool = False):
"""Compute relative positinal encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, size).
zero_triu (bool): If true, return the lower triangular part of
the matrix.
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
device=x.device,
dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(x.size()[0],
x.size()[1],
x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)
if zero_triu:
ones = torch.ones((x.size(2), x.size(3)))
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
def forward(self, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (torch.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
# NOTE(xcsong):
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
# and we will always do splitting and
# concatnation(this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors(see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.size(0) > 0:
key_cache, value_cache = torch.split(
cache, cache.size(-1) // 2, dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask), new_cache
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class GlobalCMVN(torch.nn.Module):
def __init__(self,
mean: torch.Tensor,
istd: torch.Tensor,
norm_var: bool = True):
"""
Args:
mean (torch.Tensor): mean stats
istd (torch.Tensor): inverse std, std which is 1.0 / std
"""
super().__init__()
assert mean.shape == istd.shape
self.norm_var = norm_var
# The buffer can be accessed from this module using self.mean
self.register_buffer("mean", mean)
self.register_buffer("istd", istd)
def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): (batch, max_len, feat_dim)
Returns:
(torch.Tensor): normalized feature
"""
x = x - self.mean
if self.norm_var:
x = x * self.istd
return x
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""ConvolutionModule definition."""
from typing import Tuple
import torch
from torch import nn
from typeguard import check_argument_types
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model."""
def __init__(self,
channels: int,
kernel_size: int = 15,
activation: nn.Module = nn.ReLU(),
norm: str = "batch_norm",
causal: bool = False,
bias: bool = True):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
causal (int): Whether use causal convolution or not
"""
assert check_argument_types()
super().__init__()
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
# self.lorder is used to distinguish if it's a causal convolution,
# if self.lorder > 0: it's a causal convolution, the input will be
# padded with self.lorder frames on the left in forward.
# else: it's a symmetrical convolution
if causal:
padding = 0
self.lorder = kernel_size - 1
else:
# kernel_size should be an odd number for none causal convolution
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.lorder = 0
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
bias=bias,
)
assert norm in ['batch_norm', 'layer_norm']
if norm == "batch_norm":
self.use_layer_norm = False
self.norm = nn.BatchNorm1d(channels)
else:
self.use_layer_norm = True
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = activation
def forward(
self,
x: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
cache: torch.Tensor = torch.zeros((0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, time, channels).
mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
(0, 0, 0) means fake mask.
cache (torch.Tensor): left context cache, it is only
used in causal convolution (#batch, channels, cache_t),
(0, 0, 0) meas fake cache.
Returns:
torch.Tensor: Output tensor (#batch, time, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2) # (#batch, channels, time)
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
if self.lorder > 0:
if cache.size(2) == 0: # cache_t == 0
x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
else:
assert cache.size(0) == x.size(0) # equal batch
assert cache.size(1) == x.size(1) # equal channel
x = torch.cat((cache, x), dim=2)
assert (x.size(2) > self.lorder)
new_cache = x[:, :, -self.lorder:]
else:
# It's better we just return None if no cache is required,
# However, for JIT export, here we just fake one tensor instead of
# None.
new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
if self.use_layer_norm:
x = x.transpose(1, 2)
x = self.activation(self.norm(x))
if self.use_layer_norm:
x = x.transpose(1, 2)
x = self.pointwise_conv2(x)
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
return x.transpose(1, 2), new_cache
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
class CTC(torch.nn.Module):
"""CTC module"""
def __init__(
self,
odim: int,
encoder_output_size: int,
dropout_rate: float = 0.0,
reduce: bool = True,
):
""" Construct CTC module
Args:
odim: dimension of outputs
encoder_output_size: number of encoder projection units
dropout_rate: dropout rate (0.0 ~ 1.0)
reduce: reduce the CTC loss into a scalar
"""
assert check_argument_types()
super().__init__()
eprojs = encoder_output_size
self.dropout_rate = dropout_rate
self.ctc_lo = torch.nn.Linear(eprojs, odim)
reduction_type = "sum" if reduce else "none"
self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor,
ys_pad: torch.Tensor, ys_lens: torch.Tensor) -> torch.Tensor:
"""Calculate CTC loss.
Args:
hs_pad: batch of padded hidden state sequences (B, Tmax, D)
hlens: batch of lengths of hidden state sequences (B)
ys_pad: batch of padded character id sequence tensor (B, Lmax)
ys_lens: batch of lengths of character sequence (B)
"""
# hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
# ys_hat: (B, L, D) -> (L, B, D)
ys_hat = ys_hat.transpose(0, 1)
ys_hat = ys_hat.log_softmax(2)
loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens)
# Batch-size average
loss = loss / ys_hat.size(1)
return loss
def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
"""log_softmax of frame activations
Args:
Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns:
torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
"""
return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
"""argmax of frame activations
Args:
torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns:
torch.Tensor: argmax applied 2d tensor (B, Tmax)
"""
return torch.argmax(self.ctc_lo(hs_pad), dim=2)
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Decoder definition."""
from typing import Tuple, List, Optional
import torch
from typeguard import check_argument_types
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.mask import (subsequent_mask, make_pad_mask)
class TransformerDecoder(torch.nn.Module):
"""Base class of Transfomer decoder module.
Args:
vocab_size: output dim
encoder_output_size: dimension of attention
attention_heads: the number of heads of multi head attention
linear_units: the hidden units number of position-wise feedforward
num_blocks: the number of decoder blocks
dropout_rate: dropout rate
self_attention_dropout_rate: dropout rate for attention
input_layer: input layer type
use_output_layer: whether to use output layer
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
normalize_before:
True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer.
concat_after: whether to concat attention layer's input and output
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
normalize_before: bool = True,
concat_after: bool = False,
):
assert check_argument_types()
super().__init__()
attention_dim = encoder_output_size
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
PositionalEncoding(attention_dim, positional_dropout_rate),
)
else:
raise ValueError(f"only 'embed' is supported: {input_layer}")
self.normalize_before = normalize_before
self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
self.use_output_layer = use_output_layer
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
self.num_blocks = num_blocks
self.decoders = torch.nn.ModuleList([
DecoderLayer(
attention_dim,
MultiHeadedAttention(attention_heads, attention_dim,
self_attention_dropout_rate),
MultiHeadedAttention(attention_heads, attention_dim,
src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate),
dropout_rate,
normalize_before,
concat_after,
) for _ in range(self.num_blocks)
])
def forward(
self,
memory: torch.Tensor,
memory_mask: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
r_ys_in_pad: torch.Tensor = torch.empty(0),
reverse_weight: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoder memory mask, (batch, 1, maxlen_in)
ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
ys_in_lens: input lengths of this batch (batch)
r_ys_in_pad: not used in transformer decoder, in order to unify api
with bidirectional decoder
reverse_weight: not used in transformer decoder, in order to unify
api with bidirectional decode
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out,
vocab_size) if use_output_layer is True,
torch.tensor(0.0), in order to unify api with bidirectional decoder
olens: (batch, )
"""
tgt = ys_in_pad
maxlen = tgt.size(1)
# tgt_mask: (B, 1, L)
tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
tgt_mask = tgt_mask.to(tgt.device)
# m: (1, L, L)
m = subsequent_mask(tgt_mask.size(-1),
device=tgt_mask.device).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m
x, _ = self.embed(tgt)
for layer in self.decoders:
x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.use_output_layer:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
return x, torch.tensor(0.0), olens
def forward_one_step(
self,
memory: torch.Tensor,
memory_mask: torch.Tensor,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
cache: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
This is only used for decoding.
Args:
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoded memory mask, (batch, 1, maxlen_in)
tgt: input token ids, int64 (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
cache: cached output list of (batch, max_time_out-1, size)
Returns:
y, cache: NN output value and cache per `self.decoders`.
y.shape` is (batch, maxlen_out, token)
"""
x, _ = self.embed(tgt)
new_cache = []
for i, decoder in enumerate(self.decoders):
if cache is None:
c = None
else:
c = cache[i]
x, tgt_mask, memory, memory_mask = decoder(x,
tgt_mask,
memory,
memory_mask,
cache=c)
new_cache.append(x)
if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
y = x[:, -1]
if self.use_output_layer:
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache
class BiTransformerDecoder(torch.nn.Module):
"""Base class of Transfomer decoder module.
Args:
vocab_size: output dim
encoder_output_size: dimension of attention
attention_heads: the number of heads of multi head attention
linear_units: the hidden units number of position-wise feedforward
num_blocks: the number of decoder blocks
r_num_blocks: the number of right to left decoder blocks
dropout_rate: dropout rate
self_attention_dropout_rate: dropout rate for attention
input_layer: input layer type
use_output_layer: whether to use output layer
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
normalize_before:
True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer.
concat_after: whether to concat attention layer's input and output
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
r_num_blocks: int = 0,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
normalize_before: bool = True,
concat_after: bool = False,
):
assert check_argument_types()
super().__init__()
self.left_decoder = TransformerDecoder(
vocab_size, encoder_output_size, attention_heads, linear_units,
num_blocks, dropout_rate, positional_dropout_rate,
self_attention_dropout_rate, src_attention_dropout_rate,
input_layer, use_output_layer, normalize_before, concat_after)
self.right_decoder = TransformerDecoder(
vocab_size, encoder_output_size, attention_heads, linear_units,
r_num_blocks, dropout_rate, positional_dropout_rate,
self_attention_dropout_rate, src_attention_dropout_rate,
input_layer, use_output_layer, normalize_before, concat_after)
def forward(
self,
memory: torch.Tensor,
memory_mask: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
r_ys_in_pad: torch.Tensor,
reverse_weight: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoder memory mask, (batch, 1, maxlen_in)
ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
ys_in_lens: input lengths of this batch (batch)
r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
used for right to left decoder
reverse_weight: used for right to left decoder
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out,
vocab_size) if use_output_layer is True,
r_x: x: decoded token score (right to left decoder)
before softmax (batch, maxlen_out, vocab_size)
if use_output_layer is True,
olens: (batch, )
"""
l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
ys_in_lens)
r_x = torch.tensor(0.0)
if reverse_weight > 0.0:
r_x, _, olens = self.right_decoder(memory, memory_mask, r_ys_in_pad,
ys_in_lens)
return l_x, r_x, olens
def forward_one_step(
self,
memory: torch.Tensor,
memory_mask: torch.Tensor,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
cache: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
This is only used for decoding.
Args:
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoded memory mask, (batch, 1, maxlen_in)
tgt: input token ids, int64 (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
cache: cached output list of (batch, max_time_out-1, size)
Returns:
y, cache: NN output value and cache per `self.decoders`.
y.shape` is (batch, maxlen_out, token)
"""
return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
tgt_mask, cache)
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder self-attention layer definition."""
from typing import Optional, Tuple
import torch
from torch import nn
class DecoderLayer(nn.Module):
"""Single decoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
src_attn (torch.nn.Module): Inter-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: to use layer_norm after each sub-block.
concat_after (bool): Whether to concat attention layer's inpu
and output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
def __init__(
self,
size: int,
self_attn: nn.Module,
src_attn: nn.Module,
feed_forward: nn.Module,
dropout_rate: float,
normalize_before: bool = True,
concat_after: bool = False,
):
"""Construct an DecoderLayer object."""
super().__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = nn.LayerNorm(size, eps=1e-5)
self.norm2 = nn.LayerNorm(size, eps=1e-5)
self.norm3 = nn.LayerNorm(size, eps=1e-5)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
else:
self.concat_linear1 = nn.Identity()
self.concat_linear2 = nn.Identity()
def forward(
self,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor,
cache: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute decoded features.
Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor
(#batch, maxlen_out).
memory (torch.Tensor): Encoded memory
(#batch, maxlen_in, size).
memory_mask (torch.Tensor): Encoded memory mask
(#batch, maxlen_in).
cache (torch.Tensor): cached tensors.
(#batch, maxlen_out - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, maxlen_out, size).
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (
tgt.shape[0],
tgt.shape[1] - 1,
self.size,
), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = tgt_mask[:, -1:, :]
if self.concat_after:
tgt_concat = torch.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
if self.concat_after:
x_concat = torch.cat(
(x, self.src_attn(x, memory, memory, memory_mask)[0]), dim=-1)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(
self.src_attn(x, memory, memory, memory_mask)[0])
if not self.normalize_before:
x = self.norm2(x)
residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm3(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, tgt_mask, memory, memory_mask
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Positonal Encoding Module."""
import math
from typing import Tuple, Union
import torch
import torch.nn.functional as F
class PositionalEncoding(torch.nn.Module):
"""Positional encoding.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
"""
def __init__(self,
d_model: int,
dropout_rate: float,
max_len: int = 5000,
reverse: bool = False):
"""Construct an PositionalEncoding object."""
super().__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.max_len = max_len
self.pe = torch.zeros(self.max_len, self.d_model)
position = torch.arange(0, self.max_len,
dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32) *
-(math.log(10000.0) / self.d_model))
self.pe[:, 0::2] = torch.sin(position * div_term)
self.pe[:, 1::2] = torch.cos(position * div_term)
self.pe = self.pe.unsqueeze(0)
def forward(self,
x: torch.Tensor,
offset: Union[int, torch.Tensor] = 0) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
offset (int, torch.tensor): position offset
Returns:
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
torch.Tensor: for compatibility to RelPositionalEncoding
"""
self.pe = self.pe.to(x.device)
pos_emb = self.position_encoding(offset, x.size(1), False)
x = x * self.xscale + pos_emb
return self.dropout(x), self.dropout(pos_emb)
def position_encoding(self, offset: Union[int, torch.Tensor], size: int,
apply_dropout: bool = True) -> torch.Tensor:
""" For getting encoding in a streaming fashion
Attention!!!!!
we apply dropout only once at the whole utterance level in a none
streaming way, but will call this function several times with
increasing input size in a streaming scenario, so the dropout will
be applied several times.
Args:
offset (int or torch.tensor): start offset
size (int): required size of position encoding
Returns:
torch.Tensor: Corresponding encoding
"""
# How to subscript a Union type:
# https://github.com/pytorch/pytorch/issues/69434
if isinstance(offset, int):
assert offset + size < self.max_len
pos_emb = self.pe[:, offset:offset + size]
elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
assert offset + size < self.max_len
pos_emb = self.pe[:, offset:offset + size]
else: # for batched streaming decoding on GPU
assert torch.max(offset) + size < self.max_len
index = offset.unsqueeze(1) + \
torch.arange(0, size).to(offset.device) # B X T
flag = index > 0
# remove negative offset
index = index * flag
pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
if apply_dropout:
pos_emb = self.dropout(pos_emb)
return pos_emb
class RelPositionalEncoding(PositionalEncoding):
"""Relative positional encoding module.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
"""Initialize class."""
super().__init__(d_model, dropout_rate, max_len, reverse=True)
def forward(self,
x: torch.Tensor,
offset: Union[int, torch.Tensor] = 0) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""Compute positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Positional embedding tensor (1, time, `*`).
"""
self.pe = self.pe.to(x.device)
x = x * self.xscale
pos_emb = self.position_encoding(offset, x.size(1), False)
return self.dropout(x), self.dropout(pos_emb)
class NoPositionalEncoding(torch.nn.Module):
""" No position encoding
"""
def __init__(self, d_model: int, dropout_rate: float):
super().__init__()
self.d_model = d_model
self.dropout = torch.nn.Dropout(p=dropout_rate)
def forward(self,
x: torch.Tensor,
offset: Union[int, torch.Tensor] = 0) \
-> Tuple[torch.Tensor, torch.Tensor]:
""" Just return zero vector for interface compatibility
"""
pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
return self.dropout(x), pos_emb
def position_encoding(
self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor:
return torch.zeros(1, size, self.d_model)
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple
import torch
from typeguard import check_argument_types
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.attention import RelPositionMultiHeadedAttention
from wenet.transformer.convolution import ConvolutionModule
from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.embedding import NoPositionalEncoding
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.transformer.subsampling import Conv2dSubsampling4
from wenet.transformer.subsampling import Conv2dSubsampling6
from wenet.transformer.subsampling import Conv2dSubsampling8
from wenet.transformer.subsampling import LinearNoSubsampling
from wenet.utils.common import get_activation
from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask
class BaseEncoder(torch.nn.Module):
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = "conv2d",
pos_enc_layer_type: str = "abs_pos",
normalize_before: bool = True,
concat_after: bool = False,
static_chunk_size: int = 0,
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
):
"""
Args:
input_size (int): input dim
output_size (int): dimension of attention
attention_heads (int): the number of heads of multi head attention
linear_units (int): the hidden units number of position-wise feed
forward
num_blocks (int): the number of decoder blocks
dropout_rate (float): dropout rate
attention_dropout_rate (float): dropout rate in attention
positional_dropout_rate (float): dropout rate after adding
positional encoding
input_layer (str): input layer type.
optional [linear, conv2d, conv2d6, conv2d8]
pos_enc_layer_type (str): Encoder positional encoding layer type.
opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
normalize_before (bool):
True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer.
concat_after (bool): whether to concat attention layer's input
and output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
static_chunk_size (int): chunk size for static chunk training and
decoding
use_dynamic_chunk (bool): whether use dynamic chunk size for
training or not, You can only use fixed chunk(chunk_size > 0)
or dyanmic chunk size(use_dynamic_chunk = True)
global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
dynamic chunk training
"""
assert check_argument_types()
super().__init__()
self._output_size = output_size
if pos_enc_layer_type == "abs_pos":
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "no_pos":
pos_enc_class = NoPositionalEncoding
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
if input_layer == "linear":
subsampling_class = LinearNoSubsampling
elif input_layer == "conv2d":
subsampling_class = Conv2dSubsampling4
elif input_layer == "conv2d6":
subsampling_class = Conv2dSubsampling6
elif input_layer == "conv2d8":
subsampling_class = Conv2dSubsampling8
else:
raise ValueError("unknown input_layer: " + input_layer)
self.global_cmvn = global_cmvn
self.embed = subsampling_class(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
self.normalize_before = normalize_before
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
decoding_chunk_size: int = 0,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed positions in tensor.
Args:
xs: padded input tensor (B, T, D)
xs_lens: input length (B)
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
encoder output tensor xs, and subsampled masks
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
masks: torch.Tensor batch padding mask after subsample
(B, 1, T' ~= T/subsample_rate)
"""
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks,
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks)
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
return xs, masks
def forward_chunk(
self,
xs: torch.Tensor,
offset: int,
required_cache_size: int,
att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" Forward just one chunk
Args:
xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
where `time == (chunk_size - 1) * subsample_rate + \
subsample.right_context + 1`
offset (int): current offset in encoder output time stamp
required_cache_size (int): cache size required for next chunk
compuation
>=0: actual cache size
<0: means all history cache is required
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(elayers, head, cache_t1, d_k * 2), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * num_decoding_left_chunks`.
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
(elayers, b=1, hidden-dim, cache_t2), where
`cache_t2 == cnn.lorder - 1`
Returns:
torch.Tensor: output of current input xs,
with shape (b=1, chunk_size, hidden-dim).
torch.Tensor: new attention cache required for next chunk, with
dynamic shape (elayers, head, ?, d_k * 2)
depending on required_cache_size.
torch.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
assert xs.size(0) == 1
# tmp_masks is just for interface compatibility
tmp_masks = torch.ones(1,
xs.size(1),
device=xs.device,
dtype=torch.bool)
tmp_masks = tmp_masks.unsqueeze(1)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
# NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
# NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
chunk_size = xs.size(1)
attention_key_size = cache_t1 + chunk_size
pos_emb = self.embed.position_encoding(
offset=offset - cache_t1, size=attention_key_size)
if required_cache_size < 0:
next_cache_start = 0
elif required_cache_size == 0:
next_cache_start = attention_key_size
else:
next_cache_start = max(attention_key_size - required_cache_size, 0)
r_att_cache = []
r_cnn_cache = []
for i, layer in enumerate(self.encoders):
# NOTE(xcsong): Before layer.forward
# shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
# shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
xs, _, new_att_cache, new_cnn_cache = layer(
xs, att_mask, pos_emb,
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache
)
# NOTE(xcsong): After layer.forward
# shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
# shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
if self.normalize_before:
xs = self.after_norm(xs)
# NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
# ? may be larger than cache_t1, it depends on required_cache_size
r_att_cache = torch.cat(r_att_cache, dim=0)
# NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
return (xs, r_att_cache, r_cnn_cache)
def forward_chunk_by_chunk(
self,
xs: torch.Tensor,
decoding_chunk_size: int,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
""" Forward input chunk by chunk with chunk_size like a streaming
fashion
Here we should pay special attention to computation cache in the
streaming style forward chunk by chunk. Three things should be taken
into account for computation in the current network:
1. transformer/conformer encoder layers output cache
2. convolution in conformer
3. convolution in subsampling
However, we don't implement subsampling cache for:
1. We can control subsampling module to output the right result by
overlapping input instead of cache left context, even though it
wastes some computation, but subsampling only takes a very
small fraction of computation in the whole model.
2. Typically, there are several covolution layers with subsampling
in subsampling module, it is tricky and complicated to do cache
with different convolution layers with different subsampling
rate.
3. Currently, nn.Sequential is used to stack all the convolution
layers in subsampling, we need to rewrite it to make it work
with cache, which is not prefered.
Args:
xs (torch.Tensor): (1, max_len, dim)
chunk_size (int): decoding chunk size
"""
assert decoding_chunk_size > 0
# The model is trained by static or dynamic chunk
assert self.static_chunk_size > 0 or self.use_dynamic_chunk
subsampling = self.embed.subsampling_rate
context = self.embed.right_context + 1 # Add current frame
stride = subsampling * decoding_chunk_size
decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.size(1)
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
outputs = []
offset = 0
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
# Feed forward overlap input step by step
for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :]
(y, att_cache, cnn_cache) = self.forward_chunk(
chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
outputs.append(y)
offset += y.size(1)
ys = torch.cat(outputs, 1)
masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool)
return ys, masks
class TransformerEncoder(BaseEncoder):
"""Transformer encoder module."""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = "conv2d",
pos_enc_layer_type: str = "abs_pos",
normalize_before: bool = True,
concat_after: bool = False,
static_chunk_size: int = 0,
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
):
""" Construct TransformerEncoder
See Encoder for the meaning of each parameter.
"""
assert check_argument_types()
super().__init__(input_size, output_size, attention_heads,
linear_units, num_blocks, dropout_rate,
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
concat_after, static_chunk_size, use_dynamic_chunk,
global_cmvn, use_dynamic_left_chunk)
self.encoders = torch.nn.ModuleList([
TransformerEncoderLayer(
output_size,
MultiHeadedAttention(attention_heads, output_size,
attention_dropout_rate),
PositionwiseFeedForward(output_size, linear_units,
dropout_rate), dropout_rate,
normalize_before, concat_after) for _ in range(num_blocks)
])
class ConformerEncoder(BaseEncoder):
"""Conformer encoder module."""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = "conv2d",
pos_enc_layer_type: str = "rel_pos",
normalize_before: bool = True,
concat_after: bool = False,
static_chunk_size: int = 0,
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
positionwise_conv_kernel_size: int = 1,
macaron_style: bool = True,
selfattention_layer_type: str = "rel_selfattn",
activation_type: str = "swish",
use_cnn_module: bool = True,
cnn_module_kernel: int = 15,
causal: bool = False,
cnn_module_norm: str = "batch_norm",
):
"""Construct ConformerEncoder
Args:
input_size to use_dynamic_chunk, see in BaseEncoder
positionwise_conv_kernel_size (int): Kernel size of positionwise
conv1d layer.
macaron_style (bool): Whether to use macaron style for
positionwise layer.
selfattention_layer_type (str): Encoder attention layer type,
the parameter has no effect now, it's just for configure
compatibility.
activation_type (str): Encoder activation function type.
use_cnn_module (bool): Whether to use convolution module.
cnn_module_kernel (int): Kernel size of convolution module.
causal (bool): whether to use causal convolution or not.
"""
assert check_argument_types()
super().__init__(input_size, output_size, attention_heads,
linear_units, num_blocks, dropout_rate,
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
concat_after, static_chunk_size, use_dynamic_chunk,
global_cmvn, use_dynamic_left_chunk)
activation = get_activation(activation_type)
# self-attention module definition
if pos_enc_layer_type != "rel_pos":
encoder_selfattn_layer = MultiHeadedAttention
else:
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
# feed-forward module definition
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
activation,
)
# convolution module definition
convolution_layer = ConvolutionModule
convolution_layer_args = (output_size, cnn_module_kernel, activation,
cnn_module_norm, causal)
self.encoders = torch.nn.ModuleList([
ConformerEncoderLayer(
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(
*positionwise_layer_args) if macaron_style else None,
convolution_layer(
*convolution_layer_args) if use_cnn_module else None,
dropout_rate,
normalize_before,
concat_after,
) for _ in range(num_blocks)
])
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder self-attention layer definition."""
from typing import Optional, Tuple
import torch
from torch import nn
class TransformerEncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: to use layer_norm after each sub-block.
concat_after (bool): Whether to concat attention layer's input and
output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
def __init__(
self,
size: int,
self_attn: torch.nn.Module,
feed_forward: torch.nn.Module,
dropout_rate: float,
normalize_before: bool = True,
concat_after: bool = False,
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = nn.LayerNorm(size, eps=1e-5)
self.norm2 = nn.LayerNorm(size, eps=1e-5)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if concat_after:
self.concat_linear = nn.Linear(size + size, size)
else:
self.concat_linear = nn.Identity()
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.
Args:
x (torch.Tensor): (#batch, time, size)
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
(0, 0, 0) means fake mask.
pos_emb (torch.Tensor): just for interface compatibility
to ConformerEncoderLayer
mask_pad (torch.Tensor): does not used in transformer layer,
just for unified api with conformer.
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (torch.Tensor): Convolution cache in conformer layer
(#batch=1, size, cache_t2), not used here, it's for interface
compatibility to ConformerEncoderLayer.
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time, time).
torch.Tensor: att_cache tensor,
(#batch=1, head, cache_t1 + time, d_k * 2).
torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2).
"""
residual = x
if self.normalize_before:
x = self.norm1(x)
x_att, new_att_cache = self.self_attn(
x, x, x, mask, cache=att_cache)
if self.concat_after:
x_concat = torch.cat((x, x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
return x, mask, new_att_cache, fake_cnn_cache
class ConformerEncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
instance.
`PositionwiseFeedForward` instance can be used as the argument.
conv_module (torch.nn.Module): Convolution module instance.
`ConvlutionModule` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: use layer_norm after each sub-block.
concat_after (bool): Whether to concat attention layer's input and
output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
def __init__(
self,
size: int,
self_attn: torch.nn.Module,
feed_forward: Optional[nn.Module] = None,
feed_forward_macaron: Optional[nn.Module] = None,
conv_module: Optional[nn.Module] = None,
dropout_rate: float = 0.1,
normalize_before: bool = True,
concat_after: bool = False,
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = nn.LayerNorm(size,
eps=1e-5) # for the CNN module
self.norm_final = nn.LayerNorm(
size, eps=1e-5) # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
else:
self.concat_linear = nn.Identity()
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.
Args:
x (torch.Tensor): (#batch, time, size)
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
(0, 0, 0) means fake mask.
pos_emb (torch.Tensor): positional encoding, must not be None
for ConformerEncoderLayer.
mask_pad (torch.Tensor): batch padding mask used for conv module.
(#batch, 1,time), (0, 0, 0) means fake mask.
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (torch.Tensor): Convolution cache in conformer layer
(#batch=1, size, cache_t2)
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time, time).
torch.Tensor: att_cache tensor,
(#batch=1, head, cache_t1 + time, d_k * 2).
torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
"""
# whether to use macaron style
if self.feed_forward_macaron is not None:
residual = x
if self.normalize_before:
x = self.norm_ff_macaron(x)
x = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(x))
if not self.normalize_before:
x = self.norm_ff_macaron(x)
# multi-headed self-attention module
residual = x
if self.normalize_before:
x = self.norm_mha(x)
x_att, new_att_cache = self.self_attn(
x, x, x, mask, pos_emb, att_cache)
if self.concat_after:
x_concat = torch.cat((x, x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm_mha(x)
# convolution module
# Fake new cnn cache here, and then change it in conv_module
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
if self.conv_module is not None:
residual = x
if self.normalize_before:
x = self.norm_conv(x)
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.norm_conv(x)
# feed forward module
residual = x
if self.normalize_before:
x = self.norm_ff(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm_ff(x)
if self.conv_module is not None:
x = self.norm_final(x)
return x, mask, new_att_cache, new_cnn_cache
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Label smoothing module."""
import torch
from torch import nn
class LabelSmoothingLoss(nn.Module):
"""Label-smoothing loss.
In a standard CE loss, the label's data distribution is:
[0,1,2] ->
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
]
In the smoothing version CE Loss,some probabilities
are taken from the true label prob (1.0) and are divided
among other labels.
e.g.
smoothing=0.1
[0,1,2] ->
[
[0.9, 0.05, 0.05],
[0.05, 0.9, 0.05],
[0.05, 0.05, 0.9],
]
Args:
size (int): the number of class
padding_idx (int): padding class id which will be ignored for loss
smoothing (float): smoothing rate (0.0 means the conventional CE)
normalize_length (bool):
normalize loss by sequence length if True
normalize loss by batch size if False
"""
def __init__(self,
size: int,
padding_idx: int,
smoothing: float,
normalize_length: bool = False):
"""Construct an LabelSmoothingLoss object."""
super(LabelSmoothingLoss, self).__init__()
self.criterion = nn.KLDivLoss(reduction="none")
self.padding_idx = padding_idx
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.size = size
self.normalize_length = normalize_length
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Compute loss between x and target.
The model outputs and data labels tensors are flatten to
(batch*seqlen, class) shape and a mask is applied to the
padding part which should not be calculated for loss.
Args:
x (torch.Tensor): prediction (batch, seqlen, class)
target (torch.Tensor):
target signal masked with self.padding_id (batch, seqlen)
Returns:
loss (torch.Tensor) : The KL loss, scalar float value
"""
assert x.size(2) == self.size
batch_size = x.size(0)
x = x.view(-1, self.size)
target = target.view(-1)
# use zeros_like instead of torch.no_grad() for true_dist,
# since no_grad() can not be exported by JIT
true_dist = torch.zeros_like(x)
true_dist.fill_(self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
denom = total if self.normalize_length else batch_size
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""
import torch
class PositionwiseFeedForward(torch.nn.Module):
"""Positionwise feed forward layer.
FeedForward are appied on each position of the sequence.
The output dim is same with the input dim.
Args:
idim (int): Input dimenstion.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
activation (torch.nn.Module): Activation function
"""
def __init__(self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU()):
"""Construct a PositionwiseFeedForward object."""
super(PositionwiseFeedForward, self).__init__()
self.w_1 = torch.nn.Linear(idim, hidden_units)
self.activation = activation
self.dropout = torch.nn.Dropout(dropout_rate)
self.w_2 = torch.nn.Linear(hidden_units, idim)
def forward(self, xs: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
xs: input tensor (B, L, D)
Returns:
output tensor, (B, L, D)
"""
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""
from typing import Tuple, Union
import torch
class BaseSubsampling(torch.nn.Module):
def __init__(self):
super().__init__()
self.right_context = 0
self.subsampling_rate = 1
def position_encoding(self, offset: Union[int, torch.Tensor],
size: int) -> torch.Tensor:
return self.pos_enc.position_encoding(offset, size)
class LinearNoSubsampling(BaseSubsampling):
"""Linear transform the input without subsampling
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an linear object."""
super().__init__()
self.out = torch.nn.Sequential(
torch.nn.Linear(idim, odim),
torch.nn.LayerNorm(odim, eps=1e-5),
torch.nn.Dropout(dropout_rate),
)
self.pos_enc = pos_enc_class
self.right_context = 0
self.subsampling_rate = 1
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Input x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: linear input tensor (#batch, time', odim),
where time' = time .
torch.Tensor: linear input mask (#batch, 1, time'),
where time' = time .
"""
x = self.out(x)
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask
class Conv2dSubsampling4(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/4 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling4 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
self.pos_enc = pos_enc_class
# The right context for every conv layer is computed by:
# (kernel_size - 1) * frame_rate_of_this_layer
self.subsampling_rate = 4
# 6 = (3 - 1) * 1 + (3 - 1) * 2
self.right_context = 6
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 4.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 4.
torch.Tensor: positional encoding
"""
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
class Conv2dSubsampling6(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/6 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling6 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 5, 3),
torch.nn.ReLU(),
)
self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
odim)
self.pos_enc = pos_enc_class
# 10 = (3 - 1) * 1 + (5 - 1) * 2
self.subsampling_rate = 6
self.right_context = 10
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 6.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 6.
torch.Tensor: positional encoding
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
class Conv2dSubsampling8(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/8 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling8 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
)
self.linear = torch.nn.Linear(
odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
self.pos_enc = pos_enc_class
self.subsampling_rate = 8
# 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
self.right_context = 14
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 8.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 8.
torch.Tensor: positional encoding
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
# 2020 Northwestern Polytechnical University (Pengcheng Guo)
# 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Swish() activation function for Conformer."""
import torch
class Swish(torch.nn.Module):
"""Construct an Swish object."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Return Swish activation function."""
return x * torch.sigmoid(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment