Commit 12c90639 authored by "change" (parent 417b607b)

init
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/fairseq/tasks/translation.py
1. Add a custom lang_format argument to load_langpair_dataset
2. If truncate_source is set (default: off), use RandomCropDataset instead of TruncateDataset
"""
import itertools
import logging
import os
from fairseq.data import (
AppendTokenDataset,
LanguagePairDataset,
PrependTokenDataset,
StripTokenDataset,
TruncateDataset,
RandomCropDataset,
data_utils,
indexed_dataset,
)
from speechut.data.concat_dataset import ConcatDataset
EVAL_BLEU_ORDER = 4
logger = logging.getLogger(__name__)
def load_langpair_dataset(
data_path,
split,
src,
src_dict,
tgt,
tgt_dict,
combine,
dataset_impl,
upsample_primary,
left_pad_source,
left_pad_target,
max_source_positions,
max_target_positions,
prepend_bos=False,
load_alignments=False,
truncate_source=False,
append_source_id=False,
num_buckets=0,
shuffle=True,
pad_to_multiple=1,
prepend_bos_src=None,
lang_format="[{}]",
input_feeding=True,
):
def split_exists(split, src, tgt, lang, data_path):
filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
src_datasets = []
tgt_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else "")
# infer langcode
if split_exists(split_k, src, tgt, src, data_path):
prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src, data_path):
prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
else:
if k > 0:
break
else:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, data_path)
)
src_dataset = data_utils.load_indexed_dataset(
prefix + src, src_dict, dataset_impl
)
if truncate_source:
src_dataset = AppendTokenDataset(
RandomCropDataset(
StripTokenDataset(src_dataset, src_dict.eos()),
max_source_positions - 1,
),
src_dict.eos(),
)
src_datasets.append(src_dataset)
tgt_dataset = data_utils.load_indexed_dataset(
prefix + tgt, tgt_dict, dataset_impl
)
if tgt_dataset is not None:
tgt_datasets.append(tgt_dataset)
logger.info(
"{} {} {}-{} {} examples".format(
data_path, split_k, src, tgt, len(src_datasets[-1])
)
)
if not combine:
break
assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
if len(src_datasets) == 1:
src_dataset = src_datasets[0]
tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
else:
sample_ratios = [1] * len(src_datasets)
sample_ratios[0] = upsample_primary
src_dataset = ConcatDataset(src_datasets, sample_ratios)
if len(tgt_datasets) > 0:
tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
else:
tgt_dataset = None
if prepend_bos:
assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
if tgt_dataset is not None:
tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
elif prepend_bos_src is not None:
logger.info(f"prepending src bos: {prepend_bos_src}")
src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
eos = None
if append_source_id:
src_dataset = AppendTokenDataset(
src_dataset, src_dict.index(lang_format.format(src))
)
if tgt_dataset is not None:
tgt_dataset = AppendTokenDataset(
tgt_dataset, tgt_dict.index(lang_format.format(tgt))
)
eos = tgt_dict.index(lang_format.format(tgt))
align_dataset = None
if load_alignments:
align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
align_dataset = data_utils.load_indexed_dataset(
align_path, None, dataset_impl
)
tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
return LanguagePairDataset(
src_dataset,
src_dataset.sizes,
src_dict,
tgt_dataset,
tgt_dataset_sizes,
tgt_dict,
left_pad_source=left_pad_source,
left_pad_target=left_pad_target,
align_dataset=align_dataset,
eos=eos,
num_buckets=num_buckets,
shuffle=shuffle,
pad_to_multiple=pad_to_multiple,
input_feeding=input_feeding,
)
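# --- Hedged usage sketch (not part of the original file) ---------------------
# A minimal illustration of how load_langpair_dataset might be called, assuming
# a fairseq-style binarized data directory; the paths, language codes and
# dictionary files below are placeholders, not part of this repository.
if __name__ == "__main__":
    from fairseq.data import Dictionary

    src_dict = Dictionary.load("data-bin/dict.en.txt")  # hypothetical path
    tgt_dict = Dictionary.load("data-bin/dict.de.txt")  # hypothetical path
    dataset = load_langpair_dataset(
        data_path="data-bin",
        split="train",
        src="en",
        src_dict=src_dict,
        tgt="de",
        tgt_dict=tgt_dict,
        combine=True,
        dataset_impl="mmap",
        upsample_primary=1,
        left_pad_source=True,
        left_pad_target=False,
        max_source_positions=1024,
        max_target_positions=1024,
        truncate_source=True,  # randomly crops long sources (see module docstring)
    )
    print(f"loaded {len(dataset)} sentence pairs")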
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
import logging
from os import replace
import time
from collections import OrderedDict
from typing import Any, Dict, List, Optional
import numpy as np
from fairseq.data import data_utils
from fairseq.data import FairseqDataset
logger = logging.getLogger(__name__)
class MultiCorpusDataset(FairseqDataset):
"""
see fairseq/fairseq/data/multi_corpus_dataset.__doc__
Args:
datasets: an OrderedDict of FairseqDataset instances.
max_positions: a Dict mapping each dataset name to the max size used for length filtering
distribution: a List containing the probability of sampling an utterance from the
corresponding dataset
max_tokens_ratio: a List scaling max_tokens for each dataset when batching
seed: random seed for sampling the datasets
sort_indices: if true, will sort the ordered indices by size
check_length: if true, filter each dataset's indices by max_positions (slow on large datasets)
"""
def __init__(
self,
datasets: Dict[str, FairseqDataset],
max_positions: Dict,
distribution: List[float],
max_tokens_ratio: List[float],
seed: int = 1234,
sort_indices: bool = False,
check_length: bool = False,
):
super().__init__()
assert isinstance(datasets, OrderedDict)
assert len(datasets) == len(distribution)
# assert sum(distribution) == 1
self.datasets = datasets
self.distribution = distribution
self.max_tokens_ratio = max_tokens_ratio
self.seed = seed
self.sort_indices = sort_indices
self.max_positions = max_positions
self.check_length = check_length
# Avoid repeated conversions to list later
self.dataset_list = list(datasets.values())
self.total_num_instances = 0
# first_dataset = self.dataset_list[0]
self.num_instances_per_dataset = []
self.dataset_offsets = []
for i, dataset in enumerate(self.dataset_list):
assert isinstance(dataset, FairseqDataset)
# assert type(dataset) is type(first_dataset)
self.num_instances_per_dataset.append(
0 if self.distribution[i] == 0 else len(dataset)
)
self.dataset_offsets.append(self.total_num_instances)
self.total_num_instances += self.num_instances_per_dataset[i]
def ordered_indices(self):
start = time.time()
with data_utils.numpy_seed(self.seed, self.epoch):
logger.info(f"sampling new dataset with seed {self.seed} epoch {self.epoch}")
sampled_indices = {}
# For each dataset i, sample self.distribution[i] * self.total_num_instances
for i, key in enumerate(self.datasets):
tp = time.time()
if self.distribution[i] == 0:
# skip dataset if sampling probability is 0
continue
if i < len(self.datasets) - 1:
num_instances = int(self.distribution[i] * self.total_num_instances)
high = self.dataset_offsets[i + 1]
else:
num_instances = int(self.distribution[i] * self.total_num_instances)
high = self.total_num_instances
logger.info(f"sampling {num_instances} from {key} dataset")
# First, add k copies of the dataset where k = num_instances // len(dataset).
# This ensures as even a distribution over the data points as possible.
# The remaining entries are sampled randomly without replacement.
dataset_size = len(self.datasets[key])
num_copies = num_instances // dataset_size
dataset_indices = np.random.permutation(high - self.dataset_offsets[i])[: num_instances - num_copies * dataset_size]
if num_copies > 0:
dataset_indices = np.concatenate(
(
np.repeat(
np.arange(high - self.dataset_offsets[i]), num_copies
),
dataset_indices,
)
)
# Filter by size. This is best skipped (check_length=False),
# as it is very time-consuming on large datasets.
if self.max_positions[key] is not None and self.check_length:
dataset_indices, ignored = self.datasets[key].filter_indices_by_size(
dataset_indices,
self.max_positions[key],
)
if len(ignored) > 0:
logger.warning(
(
"{:,} samples have invalid sizes and will be skipped, "
"max_positions={}, first few sample ids={}"
).format(len(ignored), self.max_positions[key], ignored[:10])
)
if self.sort_indices:
logger.info(" - sampled indices took {}s".format(time.time() - tp))
tp = time.time()
dataset_indices = np.sort(dataset_indices)
ordered_indices = self.datasets[key].ordered_indices()
if isinstance(ordered_indices[0], np.ndarray): # chunked audio data
dataset_indices = [order_idx + self.dataset_offsets[i] for order_idx in ordered_indices]
assert self.dataset_offsets[i] == 0
# TODO for chunked audio data, now assume len(dataset_indices) == len(dataset). Don't filter any data.
else:
dataset_indices = ordered_indices[dataset_indices] + self.dataset_offsets[i]
logger.info(" - ordered_indices took {}s".format(time.time() - tp))
else:
np.random.shuffle(dataset_indices)
sampled_indices[key] = dataset_indices
logger.info(
"multi_corpus_dataset ordered_indices took {}s".format(
time.time() - start
)
)
return sampled_indices
def _map_index(self, index: int):
"""
If dataset A has length N and dataset B has length M
then index 1 maps to index 1 of dataset A, and index N + 1
maps to index 1 of B.
"""
counter = 0
for num_instances, key in zip(self.num_instances_per_dataset, self.datasets):
if index < counter + num_instances:
return index - counter, key
counter += num_instances
raise ValueError(
"Invalid index: {}, max: {}".format(index, self.total_num_instances)
)
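# Worked example (illustrative only): given an OrderedDict with a "speech"
# corpus of length 100 followed by a "text" corpus of length 40, global index
# 99 maps to (99, "speech") while index 100 maps to (0, "text"), because the
# text corpus starts at offset 100.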
def __len__(self):
"""
Length of this dataset is the sum of individual datasets
"""
return self.total_num_instances
def __getitem__(self, index):
new_index, key = self._map_index(index)
try:
item = self.datasets[key][new_index]
item["full_id"] = index
return item
except Exception as e:
e.args = (f"Error from {key} dataset", *e.args)
raise
def collater(self, samples):
"""
If we are doing batch sampling, then pick the right collater to use.
Otherwise we assume all collaters are the same.
"""
if len(samples) == 0:
return None
samples_dict = {key: [] for key in self.datasets}
for s in samples:
_, key = self._map_index(s["full_id"])
samples_dict[key].append(s)
batch = {}
for key in samples_dict:
if len(samples_dict[key]) == 0:
continue
batch[key] = self.datasets[key].collater(samples_dict[key])
return batch
def num_tokens(self, index: int):
index, key = self._map_index(index)
return self.datasets[key].num_tokens(index)
def size(self, index: int):
index, key = self._map_index(index)
return self.datasets[key].size(index)
@property
def can_reuse_epoch_itr_across_epochs(self):
return False
def set_epoch(self, epoch, **unused):
super().set_epoch(epoch)
logger.info(f"setting epoch of multi_corpus_dataset to {epoch}")
for ds in self.dataset_list:
if hasattr(ds, "set_epoch"):
ds.set_epoch(epoch)
self.epoch = epoch
@property
def supports_prefetch(self):
return False
@property
def supports_fetch_outside_dataloader(self):
return all(
self.datasets[key].supports_fetch_outside_dataloader
for key in self.datasets
)
def batch_by_size(
self,
indices,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
):
dataset_indices = indices
batches_dict = {}
for n, key in enumerate(dataset_indices):
max_tokens_ratio = self.max_tokens_ratio[n]
if isinstance(dataset_indices[key][0], np.ndarray): # chunked audio data
cur_batches = self.datasets[key].batch_by_size(
dataset_indices[key],
round(max_tokens * max_tokens_ratio),
max_sentences,
required_batch_size_multiple,
)
logger.info(f"Created {sum([len(b) for b in cur_batches])} [{len(cur_batches)}] batches for dataset {key}")
else:
cur_batches = super().batch_by_size(
np.array(dataset_indices[key], dtype=np.int64),
round(max_tokens * max_tokens_ratio),
max_sentences,
required_batch_size_multiple,
)
logger.info(f"Created {len(cur_batches)} batches for dataset {key}")
batches_dict[key] = cur_batches
return batches_dict
def get_batch_sampler(
self,
indices,
num_shards,
seed,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
split_modality_batch=False,
):
def batch_sampler(dataset, epoch):
start = time.time()
batches_dict = dataset.batch_by_size(
indices,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
logger.info(f"multi_corpus_dataset, batch_by_size took {time.time() - start}s")
start = time.time()
new_batches = []
### shuffle within buckets, then split into speech/text batches
shuffled_batches_list = []
speech_batches = []
### speech_batches is handled separately because different speech datasets
# (e.g. ltr or km) need to be concatenated rather than loaded in parallel.
for name, batches in batches_dict.items():
if name.startswith("speech"):
if isinstance(batches[0], list): # chunked audio data
batches = self.datasets[name].shuffle_batches(list(batches), seed + epoch)
shuffled_batches_list.append(batches)
else:
batches = inner_bucket_shuffle(batches, seed+epoch, num_shards*10)
batches = batches[: (len(batches) // num_shards) * num_shards]
if len(batches) == 0:
logger.warning(f"Sample 0 batch for {name}, you should ensure that no {name} data provided.")
else:
speech_batches += batches
else:
batches = inner_bucket_shuffle(batches, seed+epoch, num_shards*10)
batches = batches[: (len(batches) // num_shards) * num_shards]
if len(batches) == 0:
logger.warning(f"Sample 0 batch for {name}, you should ensure that no {name} data provided.")
else:
batches = shuffle_buckets(batches, seed=seed+epoch, inner_shuf=False)
shuffled_batches_list.append(batches)
if len(speech_batches) > 0:
speech_batches = shuffle_buckets(speech_batches, seed=seed+epoch, inner_shuf=False)
shuffled_batches_list.append(speech_batches)
### create the final new_batches
num_batch = min(len(batches) for batches in shuffled_batches_list)
if split_modality_batch:
for i in range(0, num_batch, num_shards):
for batches in shuffled_batches_list:
new_batches += batches[i: i + num_shards]
else:
for i in range(num_batch):
new_batches.append(np.concatenate([batches[i] for batches in shuffled_batches_list]))
logger.info(f"multi_corpus_dataset sample {len(new_batches)} batches, took {time.time() - start}s")
return new_batches
def inner_bucket_shuffle(batches, seed, bucket_size=10, thr=0):
"""we assert batches is sorted form long to short.
shuffle samples in a buctet(e.g. 10 batches).
batches: a list of numpy array"""
num_batch = len(batches)
new_batches = []
num_buckets = len(batches) // bucket_size
i = 0
while i < num_batch:
if (i < bucket_size * thr or
i >= bucket_size * (num_buckets - thr)
):
new_batches.append(batches[i])
i += 1
else:
group = np.concatenate(batches[i: i+bucket_size])
with data_utils.numpy_seed(seed):
np.random.shuffle(group)
new_batches += np.array_split(group, bucket_size)
i += bucket_size
assert all([len(batch) > 0 for batch in new_batches])
return new_batches
def shuffle_buckets(batches, seed, inner_shuf=True):
if inner_shuf:
batches = inner_bucket_shuffle(batches, seed, num_shards*10)
batches = [batches[i: i + num_shards] for i in range(0, len(batches)-num_shards+1, num_shards)]
assert len(batches[-1]) == num_shards
new_batches = []
with data_utils.numpy_seed(seed):
np.random.shuffle(batches)
for group in batches:
new_batches += group
return new_batches
return batch_sampler
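# --- Hedged construction sketch (not part of the original file) --------------
# How a MultiCorpusDataset might be assembled from two already-built
# FairseqDataset instances. The dataset objects, max_positions values and
# sampling ratios below are placeholders.
#
#   from collections import OrderedDict
#   corpora = OrderedDict([("speech", speech_dataset), ("text", text_dataset)])
#   multi = MultiCorpusDataset(
#       datasets=corpora,
#       max_positions={"speech": 250000, "text": 1024},
#       distribution=[0.5, 0.5],        # sampling probability per corpus
#       max_tokens_ratio=[1.0, 0.25],   # per-corpus share of max_tokens
#       seed=1234,
#       sort_indices=True,
#   )
#   multi.set_epoch(1)
#   indices = multi.ordered_indices()   # dict: corpus name -> sampled indices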
# ----------------------------------------------------------------------------
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils, checkpoint_utils
from fairseq.data.data_utils import compute_mask_indices
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import ChoiceEnum
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.transformer import Embedding
from fairseq.file_io import PathManager
from torch import Tensor
from fairseq.models.wav2vec.wav2vec2 import ConvFeatureExtractionModel
from fairseq.modules import GradMultiply, LayerNorm
from fairseq.tasks.hubert_pretraining import (
HubertPretrainingConfig,
HubertPretrainingTask,
)
from fairseq.models.hubert import HubertConfig
from fairseq.models.transformer import TransformerConfig
from speechut.modules import TransformerEncoder
from speechut.modules import TransformerEncoderBase
from speechut.modules import TransformerDecoderBaseScriptable
logger = logging.getLogger(__name__)
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
@dataclass
class SpeechutConfig(HubertConfig):
use_rel_pos_enc: bool = field(
default=False,
metadata={"help": "whether to use relative positional encoding"},
)
scaling_for_att: float = field(
default=1.0,
metadata={"help": "scaling for attention weights to prevent overflow issue (for large model)"},
)
# unit encoder-decoder
text_transformer: TransformerConfig = TransformerConfig()
reset_decoder_embedding_config: bool = field(
default=False,
metadata={"help": "reset the no_scale_embedding/layernorm_embedding to default for the decoder"},
)
add_unit_encoder: bool = field(
default=False,
metadata={"help": "add unit encoder"},
)
add_decoder: bool = field(
default=True,
metadata={"help": "add decoder"},
)
add_text_ctc: bool = field(
default=False,
metadata={"help": "add_text_ctc head"},
)
text_ctc_conv_kernel: int = field(
default=2,
metadata={"help": "text_ctc_conv kernel size"},
)
mask_u2t: bool = field(
default=True,
metadata={"help": "mask the unit input in unit-to-text task"},
)
# embedding mixing
mix_with_unit: bool = field(
default=True,
metadata={"help": "mix with the unit embeddings"},
)
use_pred_unit: bool = field(
default=False,
metadata={"help": "use the embeddings of predicted units"},
)
l2_embedding: bool = field(
default=False,
metadata={"help": "compute l2 loss between unit embedding and unit hidden state"},
)
# Finetune related
encoder_dict_size: int = field(
default=-1,
metadata={"help": "text encoder dictionary dimension"},
)
decoder_dict_size: int = field(
default=-1,
metadata={"help": "decoder dictionary dimension"},
)
@register_model("speechut", dataclass=SpeechutConfig)
class SpeechutModel(BaseFairseqModel):
def __init__(
self,
cfg: SpeechutConfig,
task_cfg: HubertPretrainingConfig,
dictionaries: List[Dictionary],
unit_dictionary: Dictionary = None,
text_tgt_dictionary: Dictionary = None,
) -> None:
super().__init__()
logger.info(f"SpeechutModel Config: {cfg}")
feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
self.embed = feature_enc_layers[-1][0]
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
dropout=0.0,
mode=cfg.extractor_mode,
conv_bias=cfg.conv_bias,
)
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim)
if self.embed != cfg.encoder_embed_dim
else None
)
self.mask_prob = cfg.mask_prob
self.mask_selection = cfg.mask_selection
self.mask_other = cfg.mask_other
self.mask_length = cfg.mask_length
self.no_mask_overlap = cfg.no_mask_overlap
self.mask_min_space = cfg.mask_min_space
self.mask_channel_prob = cfg.mask_channel_prob
self.mask_channel_selection = cfg.mask_channel_selection
self.mask_channel_other = cfg.mask_channel_other
self.mask_channel_length = cfg.mask_channel_length
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
self.mask_channel_min_space = cfg.mask_channel_min_space
self.dropout_input = nn.Dropout(cfg.dropout_input)
self.dropout_features = nn.Dropout(cfg.dropout_features)
self.feature_grad_mult = cfg.feature_grad_mult
self.logit_temp = cfg.logit_temp
self.skip_masked = cfg.skip_masked
self.skip_nomask = cfg.skip_nomask
final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
self.mask_emb = nn.Parameter(
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
)
self.encoder = TransformerEncoder(cfg)
self.layer_norm = LayerNorm(self.embed)
self.target_glu = None
if cfg.target_glu:
self.target_glu = nn.Sequential(
nn.Linear(final_dim, final_dim * 2), nn.GLU()
)
self.final_dim = final_dim
assert len(dictionaries) <= 2, f"Only support <=2 kinds of targets, get {len(dictionaries)} dictionaries"
if len(dictionaries) == 1:
dictionaries = [dictionaries[0], dictionaries[0]]
self.num_classes = [len(d) for d in dictionaries]
self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
self.code_encoder_proj = nn.Linear(cfg.text_transformer.encoder.embed_dim, self.num_classes[-1])
self.final_proj_list = [self.final_proj, self.code_encoder_proj]
self.label_embs_concat = nn.Parameter(torch.FloatTensor(self.num_classes[0], final_dim))
self.label_embs_list = [self.label_embs_concat]
for p in self.label_embs_list:
nn.init.uniform_(p)
### build unit encoder:
self.mask_u2t = cfg.mask_u2t
self.add_text_ctc = cfg.add_text_ctc
self.text_ctc_conv_kernel = cfg.text_ctc_conv_kernel
self.padding_idx = unit_dictionary.pad()
self.unit_mask_idx = unit_dictionary.index("<mask>")
self.add_unit_encoder = cfg.add_unit_encoder
self.mix_with_unit = cfg.mix_with_unit
self.use_pred_unit = cfg.use_pred_unit
self.l2_embedding = cfg.l2_embedding
if self.add_unit_encoder:
assert len(unit_dictionary) == self.num_classes[0], f"unit_dictionary: {len(unit_dictionary)}, self.num_classes[0]: {self.num_classes[0]}"
### build the unit pre-net, optionally shared with the hubert label_embs (default: False)
self.unit_embed_tokens = self.build_embedding(
unit_dictionary,
cfg.text_transformer.encoder.embed_dim,
)
if self.final_dim == cfg.text_transformer.encoder.embed_dim:
logger.info("Share label_embs[0] with unit_embed_tokens ...")
nn.init.uniform_(self.unit_embed_tokens.weight)
self.label_embs_list[0] = self.unit_embed_tokens.weight
### build unit encoder
self.unit_encoder = TransformerEncoderBase(
cfg.text_transformer,
unit_dictionary,
self.unit_embed_tokens,
use_rel_pos_enc=cfg.use_rel_pos_enc,
scaling_for_att=cfg.scaling_for_att,
)
### build text ctc head
if self.add_text_ctc:
conv = nn.Conv1d(
cfg.text_transformer.encoder.embed_dim, cfg.text_transformer.encoder.embed_dim,
self.text_ctc_conv_kernel,
stride=self.text_ctc_conv_kernel // 2,
bias=False,
padding=self.text_ctc_conv_kernel // 2,
)
nn.init.kaiming_normal_(conv.weight)
self.unit_encoder_ctc_head = nn.Sequential(
Rotate3D(),
conv,
nn.Dropout(p=0.1),
nn.Sequential(
Rotate3D(),
Rotate3D(),
LayerNorm(cfg.text_transformer.encoder.embed_dim),
),
nn.GELU(),
nn.Linear(cfg.text_transformer.encoder.embed_dim, len(text_tgt_dictionary)),
)
### build unit2text decoder, not available for now
self.add_decoder = cfg.add_decoder
self.text_transformer_cfg = cfg.text_transformer
if self.add_decoder:
# To make sure that the decoder dict size is the same as the fine-tuning tgt_dict size or bpe code dict size
dec_dictionary = self.cutting_dictionary(text_tgt_dictionary, cfg.decoder_dict_size)
decoder_embed_tokens = self.build_embedding(
dec_dictionary, cfg.text_transformer.decoder.embed_dim
)
if cfg.reset_decoder_embedding_config:
cfg.text_transformer.no_scale_embedding = False
cfg.text_transformer.layernorm_embedding = False
cfg.text_transformer.no_token_positional_embeddings = False
self.decoder = TransformerDecoderBaseScriptable(cfg.text_transformer, dec_dictionary, decoder_embed_tokens, use_rel_pos_enc=cfg.use_rel_pos_enc)
def cutting_dictionary(self, dictionary, dict_size):
if dictionary is None or dict_size <= 0:
return dictionary
else:
import copy
cut_dictionary = copy.deepcopy(dictionary)
if dict_size > len(cut_dictionary):
for i in range(dict_size - len(cut_dictionary)):
cut_dictionary.symbols.append(f'_{i}_')
else:
cut_dictionary.symbols = cut_dictionary.symbols[:dict_size]
return cut_dictionary
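# Illustrative example (not executed): if text_tgt_dictionary has 10000
# symbols and cfg.decoder_dict_size is 10004, four filler symbols
# "_0_" .. "_3_" are appended; with decoder_dict_size 8000, the symbol
# list is truncated to its first 8000 entries.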
def build_embedding(self, dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, cfg: SpeechutConfig, task: HubertPretrainingTask):
"""Build a new model instance."""
unit_dictionary = getattr(task, "text_src_dictionary", None)
text_tgt_dictionary = getattr(task, "text_dictionary", None)
model = SpeechutModel(cfg, task.cfg, task.dictionaries, unit_dictionary, text_tgt_dictionary)
return model
def apply_mask(self, x, padding_mask, target_list):
B, T, C = x.shape
if self.mask_prob > 0:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x[mask_indices] = self.mask_emb
else:
mask_indices = None
if self.mask_channel_prob > 0:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x[mask_channel_indices] = 0
return x, mask_indices
def forward_features(self, source: torch.Tensor) -> torch.Tensor:
if self.feature_grad_mult > 0:
features = self.feature_extractor(source)
if self.feature_grad_mult != 1.0:
features = GradMultiply.apply(features, self.feature_grad_mult)
else:
with torch.no_grad():
features = self.feature_extractor(source)
return features
def forward_targets(
self,
features: torch.Tensor,
target_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Trim features to ensure labels exist and then get aligned labels
feat_tsz = features.size(2)
targ_tsz = min([t.size(1) for t in target_list])
if self.feat2tar_ratio * feat_tsz > targ_tsz:
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
features = features[..., :feat_tsz]
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
target_inds += np.random.choice(int(self.feat2tar_ratio))
target_list = [t[:, target_inds.long()] for t in target_list]
return features, target_list
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
padding_mask = padding_mask.all(-1)
return padding_mask
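# Example (illustrative): for a waveform-level padding mask of width T_wave and
# T feature frames, the mask is trimmed to a multiple of T, reshaped to
# (B, T, T_wave // T), and a frame is marked as padding only if all of its
# underlying samples are padding.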
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
lprobs.batch_first = True
return lprobs
def downsample_ctc_padding_mask(self, padding_mask):
"""
padding_mask: (B, T)
"""
stride = self.text_ctc_conv_kernel // 2
return padding_mask[:, ::stride]
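# Example (illustrative): with the default text_ctc_conv_kernel=2 the stride is
# 1 and the mask is returned unchanged; a kernel of 4 gives stride 2, so a
# (B, T) mask becomes (B, ceil(T / 2)) by keeping every second frame, roughly
# matching the downsampling of the strided Conv1d in the CTC head.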
def compute_pred(self, proj_x, label_embs):
if self.target_glu:
label_embs = self.target_glu(label_embs)
x = F.normalize(proj_x.float(), dim=-1) # (S, D)
label_embs = F.normalize(label_embs.float(), dim=-1) # (C, D)
logits = torch.matmul(x, label_embs.T).type_as(proj_x) # (S, C)
logits /= self.logit_temp
return logits
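# Note: the logits here are temperature-scaled cosine similarities. Both the
# projected frames (S, D) and the label embeddings (C, D) are L2-normalized, so
# logits[s, c] = cos(proj_x_s, label_emb_c) / logit_temp, an (S, C) matrix that
# is later paired with its frame-level targets for the cross-entropy loss.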
def compute_hubert_logits(self, x, target, proj, label_embs, padding_mask, mask_indices):
if not self.skip_masked:
masked_indices = torch.logical_and(~padding_mask, mask_indices)
proj_x_m = proj(x[masked_indices])
logit_m_list = [(self.compute_pred(proj_x_m, label_embs), target[masked_indices])]
else:
logit_m_list = [None]
if not self.skip_nomask:
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
proj_x_u = proj(x[nomask_indices])
logit_u_list = [(self.compute_pred(proj_x_u, label_embs), target[nomask_indices])]
else:
logit_u_list = [None]
return logit_m_list, logit_u_list
def compute_ce_logits(self, x, target, proj, padding_mask, mask_indices):
if not self.skip_masked:
masked_indices = torch.logical_and(~padding_mask, mask_indices)
logit_m_list = [(proj(x[masked_indices]), target[masked_indices])]
else:
logit_m_list = [None]
if not self.skip_nomask:
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
logit_u_list = [(proj(x[nomask_indices]), target[nomask_indices])]
else:
logit_u_list = [None]
return logit_m_list, logit_u_list
def convert_embeddings(self,
x,
padding_mask,
target=None,
mask_indices=None,
mix_with_unit=False,
use_pred_unit=False,
l2_embedding=False,
remask=False
):
"""
1. Mix with units if needed (default: True)
2. Prepare for unit_encoder inputs
Inputs:
x, (B, T, D)
Return:
src_tokens, (B, T)
soft_embeddings, (B, T, D)
l2_loss, a loss
"""
soft_embeddings = self.final_proj_list[0](x) if x.size(-1) == self.final_dim else x
if padding_mask is None:
padding_mask = soft_embeddings.new_zeros(soft_embeddings.size(0), soft_embeddings.size(1), dtype=torch.long)
if use_pred_unit:
src_tokens = self.compute_pred(self.final_proj_list[0](x), self.label_embs_list[0]).argmax(dim=-1)
src_tokens[padding_mask] = self.padding_idx
elif target is not None:
src_tokens = target
else:
src_tokens = padding_mask.long()
if l2_embedding | mix_with_unit:
unit_embeddings = self.unit_embed_tokens(src_tokens) # (B, T, D)
l2_loss = 0
if l2_embedding:
if mask_indices is not None:
l2_loss = (soft_embeddings - unit_embeddings)[mask_indices].float().pow(2).mean(dim=-1)
scale = unit_embeddings[mask_indices].float().pow(2).sum(dim=-1)
else:
l2_loss = (soft_embeddings - unit_embeddings).float().pow(2).mean(dim=-1)
scale = unit_embeddings.float().pow(2).sum(dim=-1)
l2_loss = (l2_loss / scale).mean()
if mix_with_unit:
B, T, D = x.shape
selected_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob / 2,
self.mask_length // 2,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
)
selected_indices = torch.from_numpy(selected_indices).to(x.device)
if mask_indices is not None:
if remask:
remask_indices = torch.logical_and(selected_indices, mask_indices)
soft_embeddings[remask_indices] = self.mask_emb
swap_indices = torch.logical_and(selected_indices, ~mask_indices)
else:
swap_indices = selected_indices
soft_embeddings[swap_indices] = unit_embeddings[swap_indices]
soft_embeddings = soft_embeddings * (1 - padding_mask.unsqueeze(-1).type_as(x))
return src_tokens, soft_embeddings, l2_loss
def forward(
self,
source: torch.Tensor = None,
src_tokens: torch.Tensor = None,
src_lengths: torch.Tensor = None,
prev_output_tokens: torch.Tensor = None,
target_list: Optional[List[torch.Tensor]] = None,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = True,
features_only: bool = False,
output_layer: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
assert source is not None or src_tokens is not None
if source is not None:
return self.forward_speech(
source=source,
target_list=target_list,
padding_mask=padding_mask,
mask=mask,
features_only=features_only,
output_layer=output_layer,
)
else:
return self.forward_text(
src_tokens=src_tokens,
src_lengths=src_lengths,
prev_output_tokens=prev_output_tokens,
mask=self.mask_u2t,
features_only=features_only,
output_layer=output_layer,
)
def forward_speech(
self,
source: torch.Tensor = None,
target_list: Optional[List[torch.Tensor]] = None,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = True,
features_only: bool = False,
output_layer: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
"""output layer is 1-based"""
features = self.forward_features(source)
if target_list is not None:
features, target_list = self.forward_targets(features, target_list)
features_pen = features.float().pow(2).mean()
features = features.transpose(1, 2)
features = self.layer_norm(features)
unmasked_features = features.clone()
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
features = self.dropout_input(features)
unmasked_features = self.dropout_features(unmasked_features)
if mask:
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
else:
x = features
mask_indices = None
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x, _ = self.encoder(
x,
padding_mask=padding_mask,
layer=None if output_layer is None else output_layer - 1,
)
if features_only:
return {"x": x, "padding_mask": padding_mask, "features": features}
logit_m_list, logit_u_list = self.compute_hubert_logits(
x,
target_list[0],
self.final_proj_list[0],
self.label_embs_list[0],
padding_mask,
mask_indices,
)
result = {
"logit_m_list": logit_m_list,
"logit_u_list": logit_u_list,
"padding_mask": padding_mask,
"features_pen": features_pen,
}
if self.add_unit_encoder:
src_tokens, x_emb, l2_loss = self.convert_embeddings(
x,
padding_mask, target_list[0],
mask_indices=mask_indices,
mix_with_unit=self.mix_with_unit,
use_pred_unit=self.use_pred_unit,
l2_embedding=self.l2_embedding,
)
encoder_out = self.unit_encoder(src_tokens, token_embeddings=x_emb)
result['encoder_out'] = encoder_out['encoder_out'] # [(T, B, D)]
result['encoder_padding_mask'] = encoder_out['encoder_padding_mask'] # [(B, T)]
if self.l2_embedding:
result['embedding_l2_loss'] = l2_loss
code_logit_m_list, code_logit_u_list = self.compute_ce_logits(
encoder_out['encoder_out'][0].transpose(0, 1), # -> (B, T, C)
target_list[-1],
self.final_proj_list[1],
padding_mask,
mask_indices,
)
result['logit_m_list'] += code_logit_m_list
result['logit_u_list'] += code_logit_u_list
return result
def forward_text(
self,
src_tokens: torch.Tensor = None,
src_lengths: torch.Tensor = None,
prev_output_tokens: torch.Tensor = None,
target_list: Optional[List[torch.Tensor]] = None,
mask: bool = True,
features_only: bool = False,
output_layer: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
assert self.add_unit_encoder, "Cannot forward the unit-text branch without a unit encoder!"
padding_mask = src_tokens == self.padding_idx
unit_embeddings = self.unit_embed_tokens(src_tokens)
if mask:
unit_embeddings, mask_indices = self.apply_mask(unit_embeddings, padding_mask, [src_tokens])
encoder_out = self.unit_encoder(
src_tokens,
token_embeddings=unit_embeddings,
return_all_hiddens=output_layer is not None,
)
result = {}
result["encoder_out"] = encoder_out["encoder_out"]
result["encoder_states"] = encoder_out["encoder_states"]
result["padding_mask"] = padding_mask
if self.add_text_ctc:
result["encoder_out_ctc"] = [self.unit_encoder_ctc_head(x) for x in encoder_out['encoder_out']]
result["encoder_padding_mask"] = [
self.downsample_ctc_padding_mask(padding_mask) for padding_mask in encoder_out['encoder_padding_mask']
]
if features_only:
return result
if self.add_decoder:
assert prev_output_tokens is not None
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out,
)
result['decoder_out'] = decoder_out
return result
def forward_mum(self, src_tokens, target, mask=True):
target_list = [target]
padding_mask = src_tokens.eq(self.unit_encoder.padding_idx)
unit_embeddings = self.unit_embed_tokens(src_tokens)
if mask:
unit_embeddings, mask_indices = self.apply_mask(unit_embeddings, padding_mask, target_list)
else:
### If the mask was already applied to src_tokens, then target_list should contain many padding_idx entries
mask_indices = target_list[-1] != self.padding_idx
unit_embeddings[mask_indices] = self.mask_emb
encoder_out = self.unit_encoder(
src_tokens,
token_embeddings=unit_embeddings,
)
code_logit_m_list, code_logit_u_list = self.compute_ce_logits(
encoder_out["encoder_out"][0].transpose(0, 1),
target_list[-1],
self.final_proj_list[1],
padding_mask,
mask_indices,
)
result = {}
result["logit_m_list"] = code_logit_m_list
result["logit_u_list"] = code_logit_u_list
result["padding_mask"] = padding_mask
return result
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Extract encoder features for only speech input"""
res = self.forward(
source,
padding_mask=padding_mask,
mask=mask,
features_only=True,
output_layer=output_layer,
)
x = res["x"] # B x T x D
padding_mask = res["padding_mask"]
if self.add_unit_encoder:
src_tokens, x, _ = self.convert_embeddings(
x,
padding_mask,
mix_with_unit=False,
use_pred_unit=False,
)
encoder_out = self.unit_encoder(
src_tokens,
token_embeddings=x,
return_all_hiddens=output_layer is not None
)
res["x"] = encoder_out['encoder_out'][0].transpose(0, 1) # (B, T, D)
feature = res["features"] if ret_conv else res["x"]
if output_layer is not None:
feature = encoder_out['encoder_states']
return feature, padding_mask
def get_logits(self, net_output, is_masked=True):
if is_masked:
logits_list = net_output["logit_m_list"]
else:
logits_list = net_output["logit_u_list"]
logits_list = [x[0].float() for x in logits_list if x is not None]
return logits_list
def get_targets(self, net_output, is_masked=True):
if is_masked:
logits_list = net_output["logit_m_list"]
else:
logits_list = net_output["logit_u_list"]
targets_list = [x[1].long() for x in logits_list if x is not None]
return targets_list
def get_extra_losses(self, net_output):
extra_losses = []
names = []
if "features_pen" in net_output:
extra_losses.append(net_output["features_pen"])
names.append("features_pen")
if "embedding_l2_loss" in net_output:
extra_losses.append(net_output["embedding_l2_loss"])
names.append("embedding_l2_loss")
return extra_losses, names
def remove_pretraining_modules(self, step2=False):
self.target_glu = None
def load_checkpoint(self, checkpoint: str):
if not PathManager.exists(checkpoint):
raise IOError("Model file not found: {}".format(checkpoint))
state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
return state
class Rotate3D(nn.Module):
"""
(T, B, D) --> (B, D, T) --> (D, T, B) --> (T, B, D)
"""
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(1, 2, 0)
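# Note (not part of the original file): each Rotate3D application maps
# (T, B, D) -> (B, D, T), so three applications restore the original layout.
# The CTC head above uses one rotation to feed Conv1d with (B, D, T) and two
# more to return to (T, B, D) before the LayerNorm.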
# ----------------------------------------------------------------------------
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import contextlib
import torch
from dataclasses import dataclass, field
from fairseq import utils
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.fairseq_encoder import FairseqEncoder
from fairseq.models.hubert import HubertAsrConfig, HubertEncoder
from fairseq.tasks import FairseqTask
@dataclass
class SpeechUTASRConfig(HubertAsrConfig):
add_decoder: bool = field(
default=True,
metadata={"help": "add decoder for fine-tune"},
)
@register_model("speechut_asr", dataclass=SpeechUTASRConfig)
class SpeechUTASR(BaseFairseqModel):
"""
An encoder-CTC-decoder model if cfg.add_decoder is True, otherwise an encoder-CTC model
"""
def __init__(self, cfg: SpeechUTASRConfig, encoder: FairseqEncoder):
super().__init__()
self.cfg = cfg
self.encoder = encoder
if not cfg.add_decoder:
self.encoder.w2v_model.decoder = None
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, cfg: SpeechUTASRConfig, task: FairseqTask):
"""Build a new model instance."""
encoder = SpeechUTEncoder(cfg, task)
return cls(cfg, encoder)
def forward(self, source, padding_mask, prev_output_tokens, **kwargs):
encoder_out = self.encoder(source, padding_mask, **kwargs)
x = self.encoder.final_dropout(encoder_out['encoder_out'][0]) # (T, B, C)
if self.encoder.proj:
x = self.encoder.proj(x)
if self.encoder.conv_ctc_proj:
padding_mask = self.encoder.w2v_model.downsample_ctc_padding_mask(encoder_out["encoder_padding_mask"][0])
else:
padding_mask = encoder_out["encoder_padding_mask"]
decoder_out = self.decoder(
prev_output_tokens, encoder_out=encoder_out, **kwargs
) if self.cfg.add_decoder else None
return {
"encoder_out_ctc": x, # (T, B, C), for CTC loss
"padding_mask": padding_mask, # (B, T), for CTC loss
"decoder_out": decoder_out, # for ED loss
}
def forward_decoder(self, prev_output_tokens, **kwargs):
return self.decoder(prev_output_tokens, **kwargs)
def get_logits(self, net_output):
"""For CTC decoding"""
logits = net_output["encoder_out"]
padding = net_output["encoder_padding_mask"]
if padding is not None and padding.any():
padding = padding.T
logits[padding][..., 0] = 0
logits[padding][..., 1:] = float("-inf")
return logits
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""For 1) computing CTC loss, 2) decoder decoding."""
if "encoder_out_ctc" in net_output:
logits = net_output["encoder_out_ctc"]
else:
return self.decoder.get_normalized_probs(net_output, log_probs, sample)
if isinstance(logits, list):
logits = logits[0]
if log_probs:
return utils.log_softmax(logits.float(), dim=-1)
else:
return utils.softmax(logits.float(), dim=-1)
@property
def decoder(self):
return self.encoder.w2v_model.decoder
class SpeechUTEncoder(HubertEncoder):
"""
Modified from fairseq.models.hubert.hubert_asr.HubertEncoder
1. make it compatible with encoder-decoder model
"""
def __init__(self, cfg: HubertAsrConfig, task):
super().__init__(cfg, task)
if (task.target_dictionary is not None) and (
hasattr(self.w2v_model, "unit_encoder_ctc_head")
):
self.proj = self.w2v_model.unit_encoder_ctc_head
self.conv_ctc_proj = True
else:
self.conv_ctc_proj = False
def forward(self, source, padding_mask, tbc=True, **kwargs):
w2v_args = {
"source": source,
"padding_mask": padding_mask,
"mask": self.apply_mask and self.training,
}
ft = self.freeze_finetune_updates <= self.num_updates
with torch.no_grad() if not ft else contextlib.ExitStack():
x, padding_mask = self.w2v_model.extract_features(**w2v_args)
if tbc:
# B x T x C -> T x B x C
x = x.transpose(0, 1)
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [padding_mask], # B x T
}
def forward_torchscript(self, net_input):
"""A TorchScript-compatible version of forward.
Forward the encoder out.
"""
x, padding_mask = self.w2v_model.extract_features(**net_input, mask=False)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
encoder_out = {
"encoder_out" : [x],
"encoder_padding_mask" : [padding_mask],
}
if self.proj:
x = self.proj(x)
encoder_out["encoder_out_ctc"] = x
return encoder_out
def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out["encoder_out"] is not None:
encoder_out["encoder_out"] = [
x.index_select(1, new_order) for x in encoder_out["encoder_out"]
]
if encoder_out["encoder_padding_mask"] is not None:
encoder_out["encoder_padding_mask"] = [
x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]
]
return encoder_out
# ----------------------------------------------------------------------------
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import logging
import contextlib
import torch
import torch.nn as nn
from argparse import Namespace
from dataclasses import dataclass
from typing import Any
from fairseq import checkpoint_utils, tasks
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.fairseq_encoder import FairseqEncoder
from fairseq.tasks import FairseqTask
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models.hubert import HubertAsrConfig
logger = logging.getLogger(__name__)
@dataclass
class SpeechUTS2TConfig(HubertAsrConfig):
### the following config exists only for compatibility with the fairseq speech_to_text task
input_feat_per_channel: Any = None
input_channels: Any = None
speaker_to_id: Any = None
@register_model("speechut_st_legacy", dataclass=SpeechUTS2TConfig)
class SpeechUTS2T(BaseFairseqModel):
"""An encoder-decoder model."""
def __init__(self, cfg: SpeechUTS2TConfig, encoder: FairseqEncoder):
super().__init__()
self.cfg = cfg
self.encoder = encoder
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, cfg: SpeechUTS2TConfig, task: FairseqTask):
"""Build a new model instance."""
encoder = SpeechUTEncoder(cfg, task)
return cls(cfg, encoder)
def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
encoder_out = self.encoder(src_tokens, src_lengths, **kwargs)
decoder_out = self.encoder.w2v_model.decoder(
prev_output_tokens, encoder_out=encoder_out, **kwargs
)
return decoder_out
def forward_decoder(self, prev_output_tokens, **kwargs):
return self.encoder.w2v_model.decoder(prev_output_tokens, **kwargs)
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""For decoder decoding."""
return self.encoder.w2v_model.decoder.get_normalized_probs(net_output, log_probs, sample)
@property
def decoder(self):
return self.encoder.w2v_model.decoder
class SpeechUTEncoder(FairseqEncoder):
"""
Modified from fairseq.models.hubert.hubert_asr.HubertEncoder
1. make it compatible with fairseq speech_to_text task
2. make it compatible with encoder-decoder model
"""
def __init__(self, cfg: SpeechUTS2TConfig, task):
self.apply_mask = cfg.apply_mask
arg_overrides = {
"dropout": cfg.dropout,
"activation_dropout": cfg.activation_dropout,
"dropout_input": cfg.dropout_input,
"attention_dropout": cfg.attention_dropout,
"mask_length": cfg.mask_length,
"mask_prob": cfg.mask_prob,
"mask_selection": cfg.mask_selection,
"mask_other": cfg.mask_other,
"no_mask_overlap": cfg.no_mask_overlap,
"mask_channel_length": cfg.mask_channel_length,
"mask_channel_prob": cfg.mask_channel_prob,
"mask_channel_selection": cfg.mask_channel_selection,
"mask_channel_other": cfg.mask_channel_other,
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
"encoder_layerdrop": cfg.layerdrop,
"feature_grad_mult": cfg.feature_grad_mult,
}
if cfg.w2v_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
w2v_args = state.get("cfg", None)
if w2v_args is None:
w2v_args = convert_namespace_to_omegaconf(state["args"])
cfg.w2v_args = w2v_args
else:
state = None
w2v_args = cfg.w2v_args
if isinstance(w2v_args, Namespace):
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)
assert task.data_cfg.standardize_audio() == w2v_args.task.normalize, (
"Fine-tuning works best when data normalization is the same. "
"Please check that --normalize is set or unset for "
"both pre-training and here"
)
pretrain_task = tasks.setup_task(w2v_args.task, load_local_states=False)
assert state is not None and "task_state" in state, "the stored dictionaries were not found in the checkpoint!"
# This will load the stored "dictionaries" object
pretrain_task.load_state_dict(state["task_state"])
model = pretrain_task.build_model(w2v_args.model, from_checkpoint=True)
if state is not None and not cfg.no_pretrained_weights:
try:
model.load_state_dict(state["model"], strict=True)
except Exception as e:
logger.warning(e)
model.load_state_dict(state["model"], strict=False)
model.remove_pretraining_modules()
super().__init__(pretrain_task.source_dictionary)
d = w2v_args.model.encoder_embed_dim
self.w2v_model = model
self.final_dropout = nn.Dropout(cfg.final_dropout)
self.freeze_finetune_updates = cfg.freeze_finetune_updates
self.num_updates = 0
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
super().set_num_updates(num_updates)
self.num_updates = num_updates
def forward(self, src_tokens=None, src_lengths=None, **kwargs):
w2v_args = {
"source": src_tokens,
"padding_mask": lengths_to_padding_mask(src_lengths),
"mask": self.apply_mask and self.training,
}
ft = self.freeze_finetune_updates <= self.num_updates
with torch.no_grad() if not ft else contextlib.ExitStack():
x, padding_mask = self.w2v_model.extract_features(**w2v_args)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [padding_mask], # B x T
"padding_mask": [padding_mask],
}
def forward_torchscript(self, net_input):
"""A TorchScript-compatible version of forward.
Forward the encoder out.
"""
_net_input = {
"source": net_input["src_tokens"],
"padding_mask": lengths_to_padding_mask(net_input["src_lengths"]),
"mask": False,
}
x, padding_mask = self.w2v_model.extract_features(**_net_input)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
encoder_out = {
"encoder_out" : [x],
"encoder_padding_mask" : [padding_mask],
}
return encoder_out
def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out["encoder_out"] is not None:
encoder_out["encoder_out"] = [
x.index_select(1, new_order) for x in encoder_out["encoder_out"]
]
if encoder_out["encoder_padding_mask"] is not None:
encoder_out["encoder_padding_mask"] = [
x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]
]
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
return None
def upgrade_state_dict_named(self, state_dict, name):
return state_dict
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
nn.init.constant_(m.bias, 0.0)
return m
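# --- Hedged sanity check (not part of the original file) ---------------------
# The helpers above mirror fairseq's initialization conventions: Embedding draws
# weights from N(0, d^-0.5) and zeroes the padding row, Linear uses Xavier-uniform
# weights with a zero bias. A minimal check, assuming only torch is available:
if __name__ == "__main__":
    emb = Embedding(num_embeddings=10, embedding_dim=8, padding_idx=1)
    assert torch.all(emb.weight[1] == 0)  # padding row is zeroed
    lin = Linear(in_features=8, out_features=4)
    assert torch.all(lin.bias == 0)  # bias starts at zero
    print("init helpers behave as expected")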
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
from fairseq.models import (
register_model_architecture,
)
from fairseq.models.transformer_lm import base_lm_architecture
@register_model_architecture(model_name="transformer_lm", arch_name="transformer_lm_t5")
def transformer_lm_t5(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144)
args.decoder_layers = getattr(args, "decoder_layers", 20)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_fn = getattr(args, "activation_fn", "gelu")
base_lm_architecture(args)
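# --- Hedged illustration (not part of the original file) ---------------------
# Registering this architecture makes "--arch transformer_lm_t5" selectable from
# fairseq's CLI. The function itself only fills unset attributes on the argument
# namespace; the minimal check below assumes fairseq's base_lm_architecture
# likewise falls back to getattr defaults for any remaining fields.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(dropout=0.2)  # a user-supplied value is kept as-is
    transformer_lm_t5(args)
    assert args.dropout == 0.2
    assert args.decoder_layers == 20  # unset values fall back to the T5-LM defaults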
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
from .learned_positional_embedding import LearnedPositionalEmbedding
from .multihead_attention import MultiheadAttention
from .relative_pos_enc import RelativePositionalEncoding
from .transformer_layer import TransformerEncoderLayerBase, TransformerDecoderLayerBase
from .w2v_encoder import TransformerEncoder, TransformerSentenceEncoderLayer
from .transformer_encoder import TransformerEncoderBase
from .transformer_decoder import TransformerDecoderScriptable, TransformerDecoderBaseScriptable
__all__ = [
"MultiheadAttention",
"RelativePositionalEncoding",
"LearnedPositionalEmbedding",
"TransformerEncoderLayerBase",
"TransformerDecoderLayerBase",
"TransformerEncoder",
"TransformerSentenceEncoderLayer",
"TransformerEncoderBase",
"TransformerDecoderScriptable",
"TransformerDecoderBaseScriptable",
]
#!/usr/bin/env python3
# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import numpy as np
import six
class CTCPrefixScore(object):
"""Compute CTC label sequence scores
which is based on Algorithm 2 in WATANABE et al.
"HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
but extended to efficiently compute the probabilities of multiple labels
simultaneously
"""
def __init__(self, x, blank, eos, xp):
self.xp = xp
self.logzero = -10000000000.0
self.blank = blank
self.eos = eos
self.input_length = len(x)
self.x = x
def initial_state(self):
"""Obtain an initial CTC state
:return: CTC state
"""
# initial CTC state is made of a frame x 2 tensor that corresponds to
# r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
# superscripts n and b (non-blank and blank), respectively.
r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
r[0, 1] = self.x[0, self.blank]
for i in six.moves.range(1, self.input_length):
r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
return r
def __call__(self, y, cs, r_prev):
"""Compute CTC prefix scores for next labels
:param y : prefix label sequence
:param cs : array of next labels
:param r_prev: previous CTC state
:return ctc_scores, ctc_states
"""
# initialize CTC states
output_length = len(y) - 1 # ignore sos
# new CTC states are prepared as a frame x (n or b) x n_labels tensor
# that corresponds to r_t^n(h) and r_t^b(h).
r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
xs = self.x[:, cs]
if output_length == 0:
r[0, 0] = xs[0]
r[0, 1] = self.logzero
else:
r[output_length - 1] = self.logzero
# prepare forward probabilities for the last label
r_sum = self.xp.logaddexp(
r_prev[:, 0], r_prev[:, 1]
) # log(r_t^n(g) + r_t^b(g))
last = y[-1]
if output_length > 0 and last in cs:
log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
for i in six.moves.range(len(cs)):
log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
else:
log_phi = r_sum
# compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
# and log prefix probabilities log(psi)
start = max(output_length, 1)
log_psi = r[start - 1, 0]
for t in six.moves.range(start, self.input_length):
r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
r[t, 1] = (
self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
)
log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
# get P(...eos|X) that ends with the prefix itself
eos_pos = self.xp.where(cs == self.eos)[0]
if len(eos_pos) > 0:
log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
# exclude blank probs
blank_pos = self.xp.where(cs == self.blank)[0]
if len(blank_pos) > 0:
log_psi[blank_pos] = self.logzero
# return the log prefix probability and CTC states, where the label axis
# of the CTC states is moved to the first axis to slice it easily
return log_psi, self.xp.rollaxis(r, 2)
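# --- Hedged usage sketch (not part of the original file) ---------------------
# A minimal, self-contained run of CTCPrefixScore on random log-probabilities.
# The vocabulary layout (blank=0, eos=1) and the toy shapes are assumptions.
if __name__ == "__main__":
    T, V = 6, 5  # frames x vocabulary size
    x = np.log(np.random.dirichlet(np.ones(V), size=T)).astype(np.float32)
    scorer = CTCPrefixScore(x, blank=0, eos=1, xp=np)
    r_prev = scorer.initial_state()  # (T, 2) forward variables for the empty prefix
    y = np.array([1])  # prefix containing only <sos>
    cs = np.array([2, 3, 4, 1])  # candidate next labels (the last one is eos)
    log_psi, states = scorer(y, cs, r_prev)
    print(log_psi.shape, states.shape)  # (4,), (4, 6, 2)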
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/learned_positional_embedding.py
1. Add clamping if the input length exceeds the max-source-tokens
"""
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor
class LearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
Padding ids are ignored by either offsetting based on padding_idx
or by setting padding_idx to None and ensuring that the appropriate
position ids are passed to the forward function.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.onnx_trace = False
if self.padding_idx is not None:
self.max_positions = self.num_embeddings - self.padding_idx - 1
else:
self.max_positions = self.num_embeddings
def forward(
self,
input: Tensor,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
positions: Optional[Tensor] = None,
):
"""Input is expected to be of size [bsz x seqlen]."""
assert (positions is None) or (
self.padding_idx is None
), "If positions is pre-computed then padding_idx should not be set."
if positions is None:
if incremental_state is not None:
# positions is the same for every token when decoding a single step
# Without the int() cast, it doesn't work in some cases when exporting to ONNX
positions = torch.zeros(
(1, 1), device=input.device, dtype=input.dtype
).fill_(int(self.padding_idx + input.size(1)))
else:
positions = utils.make_positions(
input, self.padding_idx, onnx_trace=self.onnx_trace
)
positions = torch.clamp(positions, max=self.padding_idx + self.max_positions)
return F.embedding(
positions,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
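# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): shows the effect of
# the clamping added above. An input longer than the number of learned
# positions reuses the last position embedding instead of indexing out of
# range. The sizes and padding index below are arbitrary assumptions.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    max_positions, dim, pad_idx = 8, 16, 1
    # fairseq convention: reserve padding_idx + 1 extra rows for the offset
    emb = LearnedPositionalEmbedding(max_positions + pad_idx + 1, dim, pad_idx)
    tokens = torch.full((2, 12), 5, dtype=torch.long)  # 12 tokens > 8 positions
    out = emb(tokens)  # positions beyond 8 are clamped to the last embedding
    print(out.shape)   # torch.Size([2, 12, 16])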
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor
from fairseq.modules import MultiheadAttention as FairseqMultiheadAttention
class MultiheadAttention(FairseqMultiheadAttention):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
scaling_for_att=1.0
):
super().__init__(
embed_dim,
num_heads,
kdim,
vdim,
dropout,
bias,
add_bias_kv,
add_zero_attn,
self_attention,
encoder_decoder_attention,
q_noise,
qn_block_size,
)
self.scaling_for_att = scaling_for_att
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
position_bias: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if key is not None:
src_len, key_bsz, _ = key.size()
if not torch.jit.is_scripting():
assert key_bsz == bsz
assert value is not None
                assert (src_len, bsz) == value.shape[:2]
if (
not self.onnx_trace
and not is_tpu # don't use PyTorch version on TPUs
and incremental_state is None
and not static_kv
# A workaround for quantization to work. Otherwise JIT compilation
# treats bias in linear module as method.
and not torch.jit.is_scripting()
and position_bias is None
):
assert key is not None and value is not None
return F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
key_padding_mask,
need_weights,
attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
        q *= (1 / self.scaling_for_att)  # compensated below: the attention logits are multiplied by scaling_for_att again after the position bias is added
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
src_len = k.size(1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
assert k.size(1) == src_len
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
        if position_bias is not None:  # add first-order relative position bias
            # position_bias: (tgt_len, src_len, head_dim); attn_weights: (bsz * num_heads, tgt_len, src_len)
            reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1)  # (tgt_len, bsz * num_heads, head_dim)
            B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))  # (tgt_len, bsz * num_heads, src_len)
            B = B.transpose(0, 1).view(bsz * self.num_heads, position_bias.size(0), position_bias.size(1))
            attn_weights += B
attn_weights *= self.scaling_for_att
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
if self.onnx_trace:
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        if self.scaling_for_att > 1.0:
            # subtract the row-wise max for numerical stability before the softmax
            attn_weights = attn_weights - attn_weights.detach().max(dim=-1, keepdim=True)[0]
if before_softmax:
return attn_weights, v
attn_weights_float = utils.softmax(
attn_weights, dim=-1, onnx_trace=self.onnx_trace
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if self.onnx_trace and attn.size(1) == 1:
# when ONNX tracing a single decoder step (sequence length == 1)
# the transpose is a no-op copy before view, thus unnecessary
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
else:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights
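# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): when a position_bias
# tensor is passed, the module above skips the fused
# F.multi_head_attention_forward fast path and adds the first-order relative
# position term to the attention logits. The sizes below are arbitrary; in the
# SpeechUT models the bias comes from RelativePositionalEncoding (defined
# below), normalized by a per-layer LayerNorm.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    embed_dim, num_heads, seq_len, bsz = 16, 4, 10, 2
    head_dim = embed_dim // num_heads
    attn = MultiheadAttention(embed_dim, num_heads, self_attention=True)
    x = torch.randn(seq_len, bsz, embed_dim)            # (time, batch, channel)
    pos_bias = torch.randn(seq_len, seq_len, head_dim)  # (query pos, key pos, head_dim)
    out, _ = attn(query=x, key=x, value=x, need_weights=False, position_bias=pos_bias)
    print(out.shape)                                    # torch.Size([10, 2, 16])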
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
import torch
class RelativePositionalEncoding(torch.nn.Module):
def __init__(self, d_model, maxlen=1000, embed_v=False):
super(RelativePositionalEncoding, self).__init__()
self.d_model = d_model
self.maxlen = maxlen
self.pe_k = torch.nn.Embedding(2*maxlen, d_model)
if embed_v:
self.pe_v = torch.nn.Embedding(2*maxlen, d_model)
self.embed_v = embed_v
def forward(self, pos_seq, incremental_state=None):
pos_seq[pos_seq < -self.maxlen] = -self.maxlen
pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
pos_seq = pos_seq + self.maxlen
if incremental_state is not None:
pos_seq = pos_seq[-1:]
if self.embed_v:
return self.pe_k(pos_seq), self.pe_v(pos_seq)
else:
return self.pe_k(pos_seq), None
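# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): builds the matrix of
# relative distances used by the encoder/decoder modules below and looks up
# one embedding per (query, key) pair; this is the tensor passed to
# MultiheadAttention as position_bias. Sizes are arbitrary.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    d_head, seq_len = 16, 12
    pos_emb = RelativePositionalEncoding(d_head, maxlen=160)
    pos_seq = torch.arange(seq_len)
    rel_pos = pos_seq[:, None] - pos_seq[None, :]  # relative distances in [-(L-1), L-1]
    pos_k, pos_v = pos_emb(rel_pos)                # pos_v is None because embed_v=False
    print(pos_k.shape)                             # torch.Size([12, 12, 16])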
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/transformer/transformer_decoder.py
"""
import math
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqIncrementalDecoder
from fairseq.models.transformer import TransformerConfig
from fairseq.modules import (
AdaptiveSoftmax,
BaseLayer,
FairseqDropout,
LayerDropModuleList,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
from speechut.modules import transformer_layer
from speechut.modules import RelativePositionalEncoding
# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
if module_name == "TransformerDecoderBase":
return "TransformerDecoder"
else:
return module_name
class TransformerDecoderBase(FairseqIncrementalDecoder):
"""
Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self,
cfg,
dictionary,
embed_tokens,
no_encoder_attn=False,
output_projection=None,
use_rel_pos_enc=False,
):
self.cfg = cfg
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
self._future_mask = torch.empty(0)
self.dropout_module = FairseqDropout(
cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
)
self.decoder_layerdrop = cfg.decoder.layerdrop
self.share_input_output_embed = cfg.share_decoder_input_output_embed
input_embed_dim = embed_tokens.embedding_dim
embed_dim = cfg.decoder.embed_dim
self.embed_dim = embed_dim
self.output_embed_dim = cfg.decoder.output_dim
self.padding_idx = embed_tokens.padding_idx
self.max_target_positions = cfg.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
self.quant_noise = apply_quant_noise_(
nn.Linear(embed_dim, embed_dim, bias=False),
cfg.quant_noise.pq,
cfg.quant_noise.pq_block_size,
)
else:
self.quant_noise = None
self.project_in_dim = (
Linear(input_embed_dim, embed_dim, bias=False)
if embed_dim != input_embed_dim
else None
)
self.embed_positions = (
PositionalEmbedding(
self.max_target_positions,
embed_dim,
self.padding_idx,
learned=cfg.decoder.learned_pos,
)
if not cfg.no_token_positional_embeddings
else None
)
if cfg.layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
else:
self.layernorm_embedding = None
self.cross_self_attention = cfg.cross_self_attention
if self.decoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
else:
self.layers = nn.ModuleList([])
self.use_rel_pos_enc = use_rel_pos_enc
self.layers.extend(
[
self.build_decoder_layer(cfg, no_encoder_attn)
for _ in range(cfg.decoder.layers)
]
)
self.num_layers = len(self.layers)
if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm:
self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
else:
self.layer_norm = None
self.project_out_dim = (
Linear(embed_dim, self.output_embed_dim, bias=False)
if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights
else None
)
self.adaptive_softmax = None
self.output_projection = output_projection
if self.output_projection is None:
self.build_output_projection(cfg, dictionary, embed_tokens)
if self.use_rel_pos_enc:
self.pos_emb = RelativePositionalEncoding(embed_dim // cfg.decoder.attention_heads, 24)
def build_output_projection(self, cfg, dictionary, embed_tokens):
if cfg.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
self.output_embed_dim,
utils.eval_str_list(cfg.adaptive_softmax_cutoff, type=int),
dropout=cfg.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if cfg.tie_adaptive_weights else None,
factor=cfg.adaptive_softmax_factor,
tie_proj=cfg.tie_adaptive_proj,
)
elif self.share_input_output_embed:
self.output_projection = nn.Linear(
self.embed_tokens.weight.shape[1],
self.embed_tokens.weight.shape[0],
bias=False,
)
self.output_projection.weight = self.embed_tokens.weight
else:
self.output_projection = nn.Linear(
self.output_embed_dim, len(dictionary), bias=False
)
nn.init.normal_(
self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
)
num_base_layers = cfg.base_layers
for i in range(num_base_layers):
self.layers.insert(
((i + 1) * cfg.decoder.layers) // (num_base_layers + 1),
BaseLayer(cfg),
)
def build_decoder_layer(self, cfg, no_encoder_attn=False):
layer = transformer_layer.TransformerDecoderLayerBase(cfg, no_encoder_attn, has_relative_attention_bias=self.use_rel_pos_enc)
checkpoint = cfg.checkpoint_activations
if checkpoint:
offload_to_cpu = cfg.offload_activations
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
# if we are checkpointing, enforce that FSDP always wraps the
# checkpointed layer, regardless of layer size
min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
return layer
def forward(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention, should be of size T x B x C
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
full_context_alignment=full_context_alignment,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
)
if not features_only:
x = self.output_layer(x)
return x, extra
def extract_features(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
return self.extract_features_scriptable(
prev_output_tokens,
encoder_out,
incremental_state,
full_context_alignment,
alignment_layer,
alignment_heads,
)
"""
A scriptable subclass of this class has an extract_features method and calls
super().extract_features, but super() is not supported in torchscript. A copy of
this function is made to be used in the subclass instead.
"""
def extract_features_scriptable(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
bs, slen = prev_output_tokens.size()
if alignment_layer is None:
alignment_layer = self.num_layers - 1
enc: Optional[Tensor] = None
padding_mask: Optional[Tensor] = None
if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
enc = encoder_out["encoder_out"][0]
assert (
enc.size()[1] == bs
), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
padding_mask = encoder_out["encoder_padding_mask"][0]
# embed positions
positions = None
if self.embed_positions is not None:
positions = self.embed_positions(
prev_output_tokens, incremental_state=incremental_state
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.quant_noise is not None:
x = self.quant_noise(x)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if self.use_rel_pos_enc:
pos_seq = torch.arange(0, slen).long().to(x.device)
pos_seq = pos_seq[:, None] - pos_seq[None, :]
pos_k, _ = self.pos_emb(pos_seq, incremental_state)
else:
pos_k = None
self_attn_padding_mask: Optional[Tensor] = None
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
# decoder layers
attn: Optional[Tensor] = None
inner_states: List[Optional[Tensor]] = [x]
for idx, layer in enumerate(self.layers):
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
x, layer_attn, _ = layer(
x,
enc,
padding_mask,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=bool((idx == alignment_layer)),
need_head_weights=bool((idx == alignment_layer)),
pos_bias=pos_k,
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float().to(x)
if attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if self.layer_norm is not None:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {"attn": [attn], "inner_states": inner_states}
def output_layer(self, features):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
return self.output_projection(features)
else:
return features
def max_positions(self):
"""Maximum output length supported by the decoder."""
if self.embed_positions is None:
return self.max_target_positions
return min(self.max_target_positions, self.embed_positions.max_positions)
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
# self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
if (
self._future_mask.size(0) == 0
or (not self._future_mask.device == tensor.device)
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
)
self._future_mask = self._future_mask.to(tensor)
return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = "{}.embed_positions.weights".format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict[
"{}.embed_positions._float_tensor".format(name)
] = torch.FloatTensor(1)
if f"{name}.output_projection.weight" not in state_dict:
if self.share_input_output_embed:
embed_out_key = f"{name}.embed_tokens.weight"
else:
embed_out_key = f"{name}.embed_out"
if embed_out_key in state_dict:
state_dict[f"{name}.output_projection.weight"] = state_dict[
embed_out_key
]
if not self.share_input_output_embed:
del state_dict[embed_out_key]
for i in range(self.num_layers):
# update layer norms
layer_norm_map = {
"0": "self_attn_layer_norm",
"1": "encoder_attn_layer_norm",
"2": "final_layer_norm",
}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
if k in state_dict:
state_dict[
"{}.layers.{}.{}.{}".format(name, i, new, m)
] = state_dict[k]
del state_dict[k]
version_key = "{}.version".format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
nn.init.constant_(m.bias, 0.0)
return m
class TransformerDecoderBaseScriptable(TransformerDecoderBase):
def extract_features(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
# call scriptable method from parent class
x, _ = self.extract_features_scriptable(
prev_output_tokens,
encoder_out,
incremental_state,
full_context_alignment,
alignment_layer,
alignment_heads,
)
return x, None
class TransformerDecoder(TransformerDecoderBase):
def __init__(
self,
args,
dictionary,
embed_tokens,
no_encoder_attn=False,
output_projection=None,
):
self.args = args
super().__init__(
TransformerConfig.from_namespace(args),
dictionary,
embed_tokens,
no_encoder_attn=no_encoder_attn,
output_projection=output_projection,
use_rel_pos_enc=getattr(args, "use_rel_pos_enc", False),
)
def build_output_projection(self, args, dictionary, embed_tokens):
super().build_output_projection(
TransformerConfig.from_namespace(args), dictionary, embed_tokens
)
def build_decoder_layer(self, args, no_encoder_attn=False):
return super().build_decoder_layer(
TransformerConfig.from_namespace(args), no_encoder_attn=no_encoder_attn
)
class TransformerDecoderScriptable(TransformerDecoder):
def extract_features(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
# call scriptable method from parent class
x, _ = self.extract_features_scriptable(
prev_output_tokens,
encoder_out,
incremental_state,
full_context_alignment,
alignment_layer,
alignment_heads,
)
return x, None
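# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): reproduces the causal
# mask that buffered_future_mask() adds to the decoder self-attention logits.
# Row i is 0.0 for columns <= i (past and current positions) and -inf for
# columns > i (future positions), so the softmax assigns them zero weight.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    dim = 5
    future_mask = torch.triu(utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1)
    print(future_mask)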
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
import math
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqEncoder
from fairseq.modules import (
FairseqDropout,
LayerDropModuleList,
LayerNorm,
SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
from fairseq.models.transformer import (
TransformerConfig,
)
from speechut.modules import transformer_layer, LearnedPositionalEmbedding
from speechut.modules import RelativePositionalEncoding
# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
if module_name == "TransformerEncoderBase":
return "TransformerEncoder"
else:
return module_name
class TransformerEncoderBase(FairseqEncoder):
"""
Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
"""
def __init__(self, cfg, dictionary, embed_tokens, use_rel_pos_enc=False, scaling_for_att=1.0):
self.cfg = cfg
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
self.dropout_module = FairseqDropout(
cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
)
self.encoder_layerdrop = cfg.encoder.layerdrop
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = cfg.max_source_positions
self.embed_tokens = embed_tokens
self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
self.embed_positions = (
PositionalEmbedding(
cfg.max_source_positions,
embed_dim,
self.padding_idx,
learned=cfg.encoder.learned_pos,
)
if not cfg.no_token_positional_embeddings
else None
)
if cfg.layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
else:
self.layernorm_embedding = None
if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
self.quant_noise = apply_quant_noise_(
nn.Linear(embed_dim, embed_dim, bias=False),
cfg.quant_noise.pq,
cfg.quant_noise.pq_block_size,
)
else:
self.quant_noise = None
if self.encoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
else:
self.layers = nn.ModuleList([])
self.use_rel_pos_enc = use_rel_pos_enc
self.scaling_for_att = scaling_for_att
self.layers.extend(
[self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)]
)
self.num_layers = len(self.layers)
if cfg.encoder.normalize_before:
self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
else:
self.layer_norm = None
if self.use_rel_pos_enc:
self.pos_emb = RelativePositionalEncoding(embed_dim // cfg.encoder.attention_heads, 160)
def build_encoder_layer(self, cfg):
layer = transformer_layer.TransformerEncoderLayerBase(cfg, has_relative_attention_bias=self.use_rel_pos_enc, scaling_for_att=self.scaling_for_att)
checkpoint = cfg.checkpoint_activations
if checkpoint:
offload_to_cpu = cfg.offload_activations
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
# if we are checkpointing, enforce that FSDP always wraps the
# checkpointed layer, regardless of layer size
min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
return layer
def forward_embedding(
self, src_tokens, token_embedding: Optional[torch.Tensor] = None
):
# embed tokens and positions
if token_embedding is None:
token_embedding = self.embed_tokens(src_tokens)
x = embed = self.embed_scale * token_embedding
if self.embed_positions is not None:
x = embed + self.embed_positions(src_tokens)
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
if self.quant_noise is not None:
x = self.quant_noise(x)
return x, embed
def forward(
self,
src_tokens,
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
uniformity_layers: Optional[List[int]] = None,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
return self.forward_scriptable(
src_tokens, src_lengths, return_all_hiddens, token_embeddings, uniformity_layers
)
    # TorchScript doesn't support the super() method, so the scriptable subclass
    # can't access the base class model in TorchScript. The current workaround is
    # to add a helper function with a different name and call that helper from
    # the scriptable subclass.
def forward_scriptable(
self,
src_tokens,
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
uniformity_layers: Optional[List[int]] = None,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
# account for padding while computing the representation
if has_pads:
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if self.use_rel_pos_enc:
x_len = x.shape[0]
pos_seq = torch.arange(0, x_len).long().to(x.device)
pos_seq = pos_seq[:, None] - pos_seq[None, :]
pos_k, pos_v = self.pos_emb(pos_seq)
else:
pos_k = None
encoder_states = []
uniformity_hiddens = []
if return_all_hiddens:
encoder_states.append(x)
if uniformity_layers is not None and 0 in uniformity_layers:
x = F.normalize(x.float(), dim=-1).type_as(x)
uniformity_hiddens.append(x)
# encoder layers
for i, layer in enumerate(self.layers):
x = layer(
x, encoder_padding_mask=encoder_padding_mask if has_pads else None,
pos_bias=pos_k,
)
if uniformity_layers is not None and i+1 in uniformity_layers:
x = F.normalize(x.float(), dim=-1).type_as(x)
uniformity_hiddens.append(x)
if return_all_hiddens:
assert encoder_states is not None
encoder_states.append(x)
if self.layer_norm is not None:
x = self.layer_norm(x)
        # The PyTorch Mobile lite interpreter does not support returning NamedTuple in
# `forward` so we use a dictionary instead.
# TorchScript does not support mixed values so the values are all lists.
# The empty list is equivalent to None.
src_lengths = (
src_tokens.ne(self.padding_idx)
.sum(dim=1, dtype=torch.int32)
.reshape(-1, 1)
.contiguous()
)
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [encoder_embedding], # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"uniformity_hiddens": uniformity_hiddens, # List[T x B x C]
"src_tokens": [],
"src_lengths": [src_lengths],
}
@torch.jit.export
def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
if len(encoder_out["encoder_out"]) == 0:
new_encoder_out = []
else:
new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
if len(encoder_out["encoder_padding_mask"]) == 0:
new_encoder_padding_mask = []
else:
new_encoder_padding_mask = [
encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
]
if len(encoder_out["encoder_embedding"]) == 0:
new_encoder_embedding = []
else:
new_encoder_embedding = [
encoder_out["encoder_embedding"][0].index_select(0, new_order)
]
if len(encoder_out["src_tokens"]) == 0:
src_tokens = []
else:
src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
if len(encoder_out["src_lengths"]) == 0:
src_lengths = []
else:
src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
return {
"encoder_out": new_encoder_out, # T x B x C
"encoder_padding_mask": new_encoder_padding_mask, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": src_tokens, # B x T
"src_lengths": src_lengths, # B x 1
}
def max_positions(self):
"""Maximum input length supported by the encoder."""
if self.embed_positions is None:
return self.max_source_positions
return min(self.max_source_positions, self.embed_positions.max_positions)
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = "{}.embed_positions.weights".format(name)
if weights_key in state_dict:
print("deleting {0}".format(weights_key))
del state_dict[weights_key]
state_dict[
"{}.embed_positions._float_tensor".format(name)
] = torch.FloatTensor(1)
for i in range(self.num_layers):
# update layer norms
self.layers[i].upgrade_state_dict_named(
state_dict, "{}.layers.{}".format(name, i)
)
version_key = "{}.version".format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
class TransformerEncoder(TransformerEncoderBase):
def __init__(self, args, dictionary, embed_tokens):
self.args = args
super().__init__(
TransformerConfig.from_namespace(args),
dictionary,
embed_tokens,
use_rel_pos_enc=getattr(args, "use_rel_pos_enc", False),
scaling_for_att=getattr(args, "scaling_for_att", 1.0),
)
def build_encoder_layer(self, args):
return super().build_encoder_layer(
TransformerConfig.from_namespace(args),
)
def PositionalEmbedding(
num_embeddings: int,
embedding_dim: int,
padding_idx: int,
learned: bool = False,
):
if learned:
# if padding_idx is specified then offset the embedding ids by
# this index and adjust num_embeddings appropriately
# TODO: The right place for this offset would be inside
# LearnedPositionalEmbedding. Move this there for a cleaner implementation.
if padding_idx is not None:
num_embeddings = num_embeddings + padding_idx + 1
m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
if padding_idx is not None:
nn.init.constant_(m.weight[padding_idx], 0)
else:
m = SinusoidalPositionalEmbedding(
embedding_dim,
padding_idx,
init_size=num_embeddings + padding_idx + 1,
)
return m
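# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the helper above
# offsets num_embeddings by padding_idx + 1 for learned embeddings, so the
# returned module ends up supporting exactly the requested number of
# positions. Sizes are arbitrary.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    m = PositionalEmbedding(num_embeddings=100, embedding_dim=32, padding_idx=1, learned=True)
    print(m.weight.shape, m.max_positions)  # torch.Size([102, 32]) 100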
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/transformer_layer.py
https://github.com/microsoft/SpeechT5/blob/main/Speech2C/speech2c/models/modules/transformer_decoder_layer.py
"""
from typing import Dict, List, Optional
import torch
from torch import Tensor
from fairseq.modules import LayerNorm
from fairseq.modules.transformer_layer import TransformerEncoderLayerBase as FairseqTransformerEncoderLayerBase
from fairseq.modules.transformer_layer import TransformerDecoderLayerBase as FairseqTransformerDecoderLayerBase
from speechut.modules import MultiheadAttention
class TransformerEncoderLayerBase(FairseqTransformerEncoderLayerBase):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: `dropout -> add residual -> layernorm`. In the
tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*cfg.encoder.normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
"""
def __init__(self, cfg, has_relative_attention_bias=False, scaling_for_att=1.0):
self.scaling_for_att = scaling_for_att
super().__init__(cfg)
if has_relative_attention_bias:
self.norm_k = LayerNorm(self.embed_dim // cfg.encoder.attention_heads)
def build_self_attention(self, embed_dim, cfg, scaling_for_att=1.0):
return MultiheadAttention(
embed_dim,
cfg.encoder.attention_heads,
dropout=cfg.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
scaling_for_att=self.scaling_for_att,
)
def forward(
self,
x,
encoder_padding_mask: Optional[Tensor],
attn_mask: Optional[Tensor] = None,
pos_bias=None,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, seq_len)` where padding elements are indicated by ``1``.
attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
where `tgt_len` is the length of output and `src_len` is the
length of input, though here both are equal to `seq_len`.
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
useful for strided self-attention.
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
# anything in original attn_mask = 1, becomes -1e8
# anything in original attn_mask = 0, becomes 0
# Note that we cannot use -inf here, because at some edge cases,
# the attention weight (before softmax) for some padded element in query
# will become -inf, which results in NaN in model parameters
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(
attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if pos_bias is not None:
pos_bias = self.norm_k(pos_bias)
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
position_bias=pos_bias,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
return x
class TransformerDecoderLayerBase(FairseqTransformerDecoderLayerBase):
"""Decoder layer block.
In the original paper each operation (multi-head attention, encoder
attention or FFN) is postprocessed with: `dropout -> add residual ->
layernorm`. In the tensor2tensor code they suggest that learning is more
robust when preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*cfg.decoder.normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, has_relative_attention_bias=False, scaling_for_att=1.0,
):
self.scaling_for_att = scaling_for_att
super().__init__(cfg,
no_encoder_attn,
add_bias_kv,
add_zero_attn,
)
if has_relative_attention_bias:
self.norm_k = LayerNorm(self.embed_dim // cfg.decoder.attention_heads)
def build_self_attention(
self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
):
return MultiheadAttention(
embed_dim,
cfg.decoder.attention_heads,
dropout=cfg.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=not cfg.cross_self_attention,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
scaling_for_att=self.scaling_for_att,
)
def build_encoder_attention(self, embed_dim, cfg):
return MultiheadAttention(
embed_dim,
cfg.decoder.attention_heads,
kdim=cfg.encoder.embed_dim,
vdim=cfg.encoder.embed_dim,
dropout=cfg.attention_dropout,
encoder_decoder_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
scaling_for_att=self.scaling_for_att,
)
def forward(
self,
x,
encoder_out: Optional[torch.Tensor] = None,
encoder_padding_mask: Optional[torch.Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
prev_attn_state: Optional[List[torch.Tensor]] = None,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
need_attn: bool = False,
need_head_weights: bool = False,
pos_bias=None,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding
elements are indicated by ``1``.
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if need_head_weights:
need_attn = True
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if pos_bias is not None:
pos_bias = self.norm_k(pos_bias)
if prev_self_attn_state is not None:
prev_key, prev_value = prev_self_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_self_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
assert incremental_state is not None
self.self_attn._set_input_buffer(incremental_state, saved_state)
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
if self.cross_self_attention and not (
incremental_state is not None
and _self_attn_input_buffer is not None
and "prev_key" in _self_attn_input_buffer
):
if self_attn_mask is not None:
assert encoder_out is not None
self_attn_mask = torch.cat(
(x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
)
if self_attn_padding_mask is not None:
if encoder_padding_mask is None:
assert encoder_out is not None
encoder_padding_mask = self_attn_padding_mask.new_zeros(
encoder_out.size(1), encoder_out.size(0)
)
self_attn_padding_mask = torch.cat(
(encoder_padding_mask, self_attn_padding_mask), dim=1
)
assert encoder_out is not None
y = torch.cat((encoder_out, x), dim=0)
else:
y = x
x, attn = self.self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
position_bias=pos_bias,
)
if self.c_attn is not None:
tgt_len, bsz = x.size(0), x.size(1)
x = x.view(tgt_len, bsz, self.nh, self.head_dim)
x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
x = x.reshape(tgt_len, bsz, self.embed_dim)
if self.attn_ln is not None:
x = self.attn_ln(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
if self.encoder_attn is not None and encoder_out is not None:
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
if self.ffn_layernorm is not None:
x = self.ffn_layernorm(x)
x = self.fc2(x)
x = self.dropout_module(x)
if self.w_resid is not None:
residual = torch.mul(self.w_resid, residual)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.onnx_trace and incremental_state is not None:
saved_state = self.self_attn._get_input_buffer(incremental_state)
assert saved_state is not None
if self_attn_padding_mask is not None:
self_attn_state = [
saved_state["prev_key"],
saved_state["prev_value"],
saved_state["prev_key_padding_mask"],
]
else:
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
return x, attn, self_attn_state
return x, attn, None
def make_generation_fast_(self, need_attn: bool = False, **kwargs):
self.need_attn = need_attn
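# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): builds one encoder
# layer with relative attention bias enabled and runs a forward pass. It
# assumes that fairseq's TransformerConfig() defaults (512-dim model, 8 heads)
# are sufficient to construct the layer; the sequence length and batch size
# are arbitrary.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from fairseq.models.transformer import TransformerConfig

    cfg = TransformerConfig()
    layer = TransformerEncoderLayerBase(cfg, has_relative_attention_bias=True)
    x = torch.randn(7, 2, cfg.encoder.embed_dim)  # (seq_len, batch, embed_dim)
    pos_bias = torch.randn(7, 7, cfg.encoder.embed_dim // cfg.encoder.attention_heads)
    y = layer(x, encoder_padding_mask=None, pos_bias=pos_bias)
    print(y.shape)  # torch.Size([7, 2, 512])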
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
wav2vec encoder with added relative position bias, modified from
https://github.com/microsoft/SpeechT5/blob/main/Speech2C/speech2c/models/modules/transformer_encoder.py
https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
"""
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.dataclass import ChoiceEnum
from fairseq.modules import (
LayerNorm,
SamePad,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.utils import index_put
from fairseq.distributed import fsdp_wrap
from fairseq.models.wav2vec.utils import pad_to_multiple
## reload multi-head attention with rel-pos-bias
from fairseq.models.wav2vec.wav2vec2 import TransformerEncoder as W2vTransformerEncoder
from speechut.modules import RelativePositionalEncoding
from speechut.modules import MultiheadAttention
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
class TransformerEncoder(W2vTransformerEncoder):
def __init__(self, args):
super().__init__(args)
self.dropout = args.dropout
self.embedding_dim = args.encoder_embed_dim
self.required_seq_len_multiple = args.required_seq_len_multiple
self.use_rel_pos_enc = getattr(args, "use_rel_pos_enc", False)
self.pos_conv = nn.Conv1d(
self.embedding_dim,
self.embedding_dim,
kernel_size=args.conv_pos,
padding=args.conv_pos // 2,
groups=args.conv_pos_groups,
)
dropout = 0
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
nn.init.constant_(self.pos_conv.bias, 0)
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
layers = []
for _ in range(args.encoder_layers):
layer = TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
has_relative_attention_bias=self.use_rel_pos_enc,
)
if args.checkpoint_activations:
layer = fsdp_wrap(layer)
layer = checkpoint_wrapper(layer)
layers.append(layer)
self.layers = nn.ModuleList(layers)
self.layer_norm_first = args.layer_norm_first
self.layer_norm = LayerNorm(self.embedding_dim)
self.layerdrop = args.encoder_layerdrop
if self.use_rel_pos_enc:
self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim // args.encoder_attention_heads, 160)
self.apply(init_bert_params)
def forward(self, x, padding_mask=None, layer=None):
x, layer_results = self.extract_features(x, padding_mask, layer)
if self.layer_norm_first and layer is None:
x = self.layer_norm(x)
return x, layer_results
def extract_features(self, x, padding_mask=None, tgt_layer=None):
if padding_mask is not None:
x = index_put(x, padding_mask, 0)
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
x = x + x_conv
if not self.layer_norm_first:
x = self.layer_norm(x)
# pad to the sequence length dimension
x, pad_length = pad_to_multiple(
x, self.required_seq_len_multiple, dim=-2, value=0
)
if pad_length > 0 and padding_mask is None:
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
padding_mask[:, -pad_length:] = True
else:
padding_mask, _ = pad_to_multiple(
padding_mask, self.required_seq_len_multiple, dim=-1, value=True
)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if self.use_rel_pos_enc:
x_len = x.shape[0]
pos_seq = torch.arange(0, x_len).long().to(x.device)
pos_seq = pos_seq[:, None] - pos_seq[None, :]
pos_k, pos_v = self.pos_emb(pos_seq)
else:
pos_k = None
layer_results = []
r = None
for i, layer in enumerate(self.layers):
dropout_probability = np.random.random()
if not self.training or (dropout_probability > self.layerdrop):
x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_k)
if tgt_layer is not None:
# unpad if needed
if pad_length > 0:
layer_results.append(
(
x[:-pad_length],
z[:, :-pad_length, :-pad_length]
if z is not None
else z,
)
)
else:
layer_results.append((x, z))
if i == tgt_layer:
r = x
break
if r is not None:
x = r
# T x B x C -> B x T x C
x = x.transpose(0, 1)
        # undo padding
if pad_length > 0:
x = x[:, :-pad_length]
return x, layer_results
class TransformerSentenceEncoderLayer(nn.Module):
"""
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
models.
"""
def __init__(
self,
embedding_dim: float = 768,
ffn_embedding_dim: float = 3072,
num_attention_heads: float = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
layer_norm_first: bool = False,
has_relative_attention_bias: bool = False,
) -> None:
super().__init__()
# Initialize parameters
self.embedding_dim = embedding_dim
self.dropout = dropout
self.activation_dropout = activation_dropout
# Initialize blocks
self.activation_fn = utils.get_activation_fn(activation_fn)
self.self_attn = MultiheadAttention(
self.embedding_dim,
num_attention_heads,
dropout=attention_dropout,
self_attention=True,
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(self.activation_dropout)
self.dropout3 = nn.Dropout(dropout)
self.layer_norm_first = layer_norm_first
# layer norm associated with the self attention layer
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
# layer norm associated with the position wise feed-forward NN
self.final_layer_norm = LayerNorm(self.embedding_dim)
if has_relative_attention_bias:
self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads)
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
att_args=None,
pos_bias=None,
):
"""
LayerNorm is applied either before or after the self-attention/ffn
        modules, similar to the original Transformer implementation.
"""
residual = x
if self.layer_norm_first:
x = self.self_attn_layer_norm(x)
if pos_bias is not None:
pos_bias = self.norm_k(pos_bias)
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
attn_mask=self_attn_mask,
position_bias=pos_bias,
)
x = self.dropout1(x)
x = residual + x
residual = x
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
else:
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
position_bias=pos_bias,
)
x = self.dropout1(x)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
x = self.final_layer_norm(x)
return x, attn
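A minimal usage sketch for the encoder layer defined above (not part of the original file): it assumes torch is installed and that TransformerSentenceEncoderLayer and its custom MultiheadAttention are importable from this module; tensor shapes follow the fairseq T x B x C convention.
if __name__ == "__main__":
    import torch

    # Build a single pre-LN layer without relative attention bias
    # (pos_bias stays None, so norm_k is never touched).
    layer = TransformerSentenceEncoderLayer(
        embedding_dim=768,
        ffn_embedding_dim=3072,
        num_attention_heads=8,
        layer_norm_first=True,
        has_relative_attention_bias=False,
    )
    layer.eval()

    T, B, C = 50, 2, 768
    x = torch.randn(T, B, C)                            # fairseq layers expect T x B x C
    padding_mask = torch.zeros(B, T, dtype=torch.bool)  # True marks padded frames

    with torch.no_grad():
        y, _ = layer(x, self_attn_padding_mask=padding_mask)
    print(y.shape)  # expected: torch.Size([50, 2, 768])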
# ####################################
# SpeechUT Base model (ASR pretraining) #
# ####################################
[ $# -lt 2 ] && echo "Usage: $0 <data_dir> <text_data_dir> [mount=${PWD}] [world_size=32] [update_freq=1]" && exit 1
[ ${PWD##*/} != SpeechUT ] && echo "Error: directory mismatch! Switch to SpeechUT/ and run it again." && exit 1
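# Example invocation (hypothetical paths):
#   bash <this_script> /path/to/librispeech_manifests /path/to/text_data_bin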
DATA_DIR=$1
TEXT_DATA_DIR=$2
mount=$3
world_size=$4
update_freq=$5
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=32
[ -z $update_freq ] && update_freq=1
CODE_ROOT=${PWD}
MODEL_DIR="${mount}/exp/pretrain/base_speechut4asr_${world_size}gpu_${update_freq}accum"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
--config-dir $CODE_ROOT/speechut/config/pretrain \
--config-name speechut_base_librispeech \
common.user_dir=$CODE_ROOT/speechut \
\
task.labels='["km"]' \
model.label_rate=50 \
task.data=$DATA_DIR \
task.label_dir=$DATA_DIR \
task.text_cfg.text_data=$TEXT_DATA_DIR \
\
dataset.train_subset=\"train_960+pseudo_libritext.kmu-ltr+merge_960.kmu-none\" \
dataset.valid_subset=\"dev_clean+dev.kmu-ltr+dev.kmu-none\" \
dataset.num_workers=0 \
dataset.max_tokens=1400000 \
distributed_training.distributed_world_size=${world_size} \
optimization.update_freq=[${update_freq}] \
\
common.tensorboard_logdir=$MODEL_DIR \
checkpoint.save_dir=$MODEL_DIR \
hydra.run.dir=$MODEL_DIR \
hydra.job.name=base_speechut4asr_${world_size}gpu_${update_freq}accum
# ####################################
# SpeechUT Base model (EN-DE/ES ST pretraining) #
# ####################################
[ $# -lt 3 ] && echo "Usage: $0 <data_dir> <text_data_dir> <lang=de/es> [mount=${PWD}] [world_size=32] [update_freq=1]" && exit 1
[ ${PWD##*/} != SpeechUT ] && echo "Error: directory mismatch! Switch to SpeechUT/ and run it again." && exit 1
DATA_DIR=$1
TEXT_DATA_DIR=$2
lang=$3
mount=$4
world_size=$5
update_freq=$6
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=32
[ -z $update_freq ] && update_freq=1
CODE_ROOT=${PWD}
MODEL_DIR="${mount}/exp/pretrain/base_speechut4en${lang}_${world_size}gpu_${update_freq}accum"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
--config-dir $CODE_ROOT/speechut/config/pretrain \
--config-name speechut_base_librispeech \
common.user_dir=$CODE_ROOT/speechut \
\
task.labels='["km"]' \
model.label_rate=50 \
task.data=$DATA_DIR \
task.label_dir=$DATA_DIR \
task.text_cfg.text_data=$TEXT_DATA_DIR \
\
model.add_text_ctc=false \
model.text_transformer.share_decoder_input_output_embed=true \
criterion.u2t_ed_weight=1.0 \
criterion.u2t_ctc_weight=0 \
\
dataset.train_subset=\"train_960,mustcuns_${lang}+pseudo_wmt_en${lang}.kmu-spm+train_960.kmu-none,mustcuns_${lang}.kmu-none\" \
dataset.valid_subset=\"dev_clean+pseudo_valid.kmu-spm+dev.kmu-none\" \
dataset.num_workers=0 \
dataset.max_tokens=1400000 \
distributed_training.distributed_world_size=${world_size} \
optimization.update_freq=[${update_freq}] \
\
common.tensorboard_logdir=$MODEL_DIR \
checkpoint.save_dir=$MODEL_DIR \
hydra.run.dir=$MODEL_DIR \
hydra.job.name=base_speechut4en${lang}_${world_size}gpu_${update_freq}accum
# ####################################
# SpeechUT Base model (EN-FR ST pretraining) #
# ####################################
[ $# -lt 2 ] && echo "Usage: $0 <data_dir> <text_data_dir> [lang=fr] [mount=${PWD}] [world_size=32] [update_freq=1]" && exit 1
[ ${PWD##*/} != SpeechUT ] && echo "Error: directory mismatch! Switch to SpeechUT/ and run it again." && exit 1
DATA_DIR=$1
TEXT_DATA_DIR=$2
lang=$3
mount=$4
world_size=$5
update_freq=$6
[ -z $lang ] && lang=fr
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=32
[ -z $update_freq ] && update_freq=1
CODE_ROOT=${PWD}
MODEL_DIR="${mount}/exp/pretrain/base_speechut4en${lang}_${world_size}gpu_${update_freq}accum"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
--config-dir $CODE_ROOT/speechut/config/pretrain \
--config-name speechut_base_librispeech \
common.user_dir=$CODE_ROOT/speechut \
\
task.labels='["km"]' \
model.label_rate=50 \
task.data=$DATA_DIR \
task.label_dir=$DATA_DIR \
task.text_cfg.text_data=$TEXT_DATA_DIR \
\
model.add_text_ctc=false \
criterion.u2t_ed_weight=1.0 \
criterion.u2t_ctc_weight=0 \
\
dataset.train_subset=\"train_960,pretrain_mustc+pseudo_wmt14_enfr.kmu-spm+train_960.kmu-none,pretrain_mustc.kmu-none\" \
dataset.valid_subset=\"dev_clean+pseudo_valid.kmu-spm+dev.kmu-none\" \
dataset.num_workers=0 \
dataset.max_tokens=1400000 \
optimization.max_update=600000 \
distributed_training.distributed_world_size=${world_size} \
optimization.update_freq=[${update_freq}] \
\
common.tensorboard_logdir=$MODEL_DIR \
checkpoint.save_dir=$MODEL_DIR \
hydra.run.dir=$MODEL_DIR \
hydra.job.name=base_speechut4en${lang}_${world_size}gpu_${update_freq}accum
# ####################################
# SpeechUT Large model (ASR pretraining) #
# ####################################
[ $# -lt 2 ] && echo "Usage: $0 <data_dir> <text_data_dir> [mount=${PWD}] [world_size=32] [update_freq=4]" && exit 1
[ ${PWD##*/} != SpeechUT ] && echo "Error: directory mismatch! Switch to SpeechUT/ and run it again." && exit 1
DATA_DIR=$1
TEXT_DATA_DIR=$2
mount=$3
world_size=$4
update_freq=$5
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=32
[ -z $update_freq ] && update_freq=4
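# Note: the effective batch scales with world_size x update_freq x dataset.max_tokens;
# the defaults here (32 GPUs, 4-step gradient accumulation) match the directory name below.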
CODE_ROOT=${PWD}
MODEL_DIR="${mount}/exp/pretrain/large_speechut4asr_${world_size}gpu_${update_freq}accum"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
--config-dir $CODE_ROOT/speechut/config/pretrain \
--config-name speechut_large_librilight \
common.user_dir=$CODE_ROOT/speechut \
\
task.labels='["km"]' \
model.label_rate=50 \
task.data=$DATA_DIR \
task.label_dir=$DATA_DIR \
task.text_cfg.text_data=$TEXT_DATA_DIR \
\
dataset.train_subset=\"train_small+pseudo_libritext.kmu-ltr\" \
dataset.valid_subset=\"dev_clean+dev.kmu-ltr\" \
dataset.num_workers=0 \
dataset.max_tokens=900000 \
distributed_training.distributed_world_size=${world_size} \
optimization.update_freq=[${update_freq}] \
\
common.tensorboard_logdir=$MODEL_DIR \
checkpoint.save_dir=$MODEL_DIR \
hydra.run.dir=$MODEL_DIR \
hydra.job.name=large_speechut4asr_${world_size}gpu_${update_freq}accum