# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional
import torch
from fairseq.dataclass import FairseqDataclass
from fairseq.models import (
FairseqIncrementalDecoder,
FairseqLanguageModel,
register_model,
)
from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel
logger = logging.getLogger(__name__)
@dataclass
class AdaptiveSpanSmallConfig(FairseqDataclass):
# defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh
vocab_size: int = 50
d_model: int = 256
n_head: int = 4
d_inner: int = 1024
n_layer: int = 8
attn_span: int = 1024
dropout: float = 0.0
emb_dropout: float = 0.0
adapt_span_ramp: int = 32
adapt_span_init: float = 0.0
aux_loss_scaler: float = 0.000002
adapt_span_layer: bool = False
@register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig)
class AdaptiveSpanTransformer(FairseqLanguageModel):
@classmethod
def build_model(cls, cfg: AdaptiveSpanSmallConfig, task):
return cls(AdaptiveSpanDecoder(cfg, task))
def get_aux_loss(self):
return self.decoder.get_aux_loss()
def get_current_max_span(self):
return self.decoder.get_current_max_span()
def get_current_avg_span(self):
return self.decoder.get_current_avg_span()
class AdaptiveSpanDecoder(FairseqIncrementalDecoder):
def __init__(self, cfg, task):
super().__init__(task.target_dictionary)
self.config = cfg
config = AdaptiveSpanSmallConfig(
vocab_size=len(task.target_dictionary),
d_model=cfg.d_model,
n_head=cfg.n_head,
d_inner=cfg.d_inner,
n_layer=cfg.n_layer,
attn_span=cfg.attn_span,
dropout=cfg.dropout,
emb_dropout=cfg.emb_dropout,
adapt_span_ramp=cfg.adapt_span_ramp,
adapt_span_init=cfg.adapt_span_init,
aux_loss_scaler=cfg.aux_loss_scaler,
adapt_span_layer=cfg.adapt_span_layer,
)
logger.info(config)
self.model = AdaptiveSpanTransformerModel(**config.__dict__)
self._mems = None
def forward(
self,
src_tokens,
incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
encoder_out=None,
):
bsz = src_tokens.size(0)
if incremental_state is not None: # used during inference
mems = self.get_incremental_state("mems")
src_tokens = src_tokens[:, -1:] # only keep the most recent token
else:
mems = self._mems
if mems is None:
# first time init
mems = self.init_hid_cache(bsz)
output = self.model(x=src_tokens, h_cache=mems,)
if incremental_state is not None:
self.set_incremental_state(incremental_state, "mems", output[1])
else:
self._mems = output[1]
return (output[0],)
def max_positions(self):
return self.config.attn_span
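    # The hidden-state cache holds one (batch, cache_size, d_model) tensor per
    # layer, allocated with the same dtype and device as the model parameters.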
def init_hid_cache(self, batch_sz):
hid = []
for layer in self.model.layers:
param = next(self.model.parameters())
h = torch.zeros(
batch_sz,
layer.get_cache_size(),
self.config.d_model,
dtype=param.dtype,
device=param.device,
)
hid.append(h)
return hid
def get_aux_loss(self):
return self.model.get_aux_loss()
def get_current_max_span(self):
return self.model.get_current_max_span()
def get_current_avg_span(self):
return self.model.get_current_avg_span()
def reorder_incremental_state(
self,
incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
new_order: torch.Tensor,
):
"""Reorder incremental state.
This will be called when the order of the input has changed from the
previous time step. A typical use case is beam search, where the input
order changes between time steps based on the selection of beams.
"""
raise NotImplementedError("This is required for generation/beam search")
# mems = self.get_incremental_state(incremental_state, "mems")
# if mems is not None:
# new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
# self.set_incremental_state(incremental_state, "mems", new_mems)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import torch
from fairseq import utils
from fairseq.data import (
Dictionary,
TokenBlockDataset,
data_utils,
iterators,
)
from fairseq.dataclass import FairseqDataclass
from fairseq.distributed import utils as dist_utils
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II
logger = logging.getLogger(__name__)
@dataclass
class TruncatedBPTTLMConfig(FairseqDataclass):
data: str = field(default="???", metadata={"help": "path to data directory"})
tokens_per_sample: int = field(
default=1024, metadata={"help": "max number of tokens per sequence"},
)
batch_size: int = II("dataset.batch_size")
# Some models use *max_target_positions* to know how many positional
# embeddings to learn. We use II(...) to make it default to
# *tokens_per_sample*, but in principle there could be more positional
# embeddings than tokens in a single batch. This may also be irrelevant for
# custom model implementations.
max_target_positions: int = II("task.tokens_per_sample")
# these will be populated automatically if not provided
data_parallel_rank: Optional[int] = None
data_parallel_size: Optional[int] = None
@register_task("truncated_bptt_lm", dataclass=TruncatedBPTTLMConfig)
class TruncatedBPTTLMTask(FairseqTask):
def __init__(self, cfg: TruncatedBPTTLMConfig):
super().__init__(cfg)
if cfg.data_parallel_rank is None or cfg.data_parallel_size is None:
if torch.distributed.is_initialized():
cfg.data_parallel_rank = dist_utils.get_data_parallel_rank()
cfg.data_parallel_size = dist_utils.get_data_parallel_world_size()
else:
cfg.data_parallel_rank = 0
cfg.data_parallel_size = 1
# load the dictionary
paths = utils.split_paths(cfg.data)
assert len(paths) > 0
self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
logger.info("dictionary: {} types".format(len(self.dictionary)))
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split (e.g., train, valid, test)"""
# support sharded datasets
paths = utils.split_paths(self.cfg.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
# each element of *data* will be a tensorized line from the original
# text dataset, similar to ``open(split_path).readlines()``
data = data_utils.load_indexed_dataset(
split_path, self.dictionary, combine=combine
)
if data is None:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, split_path)
)
# this is similar to ``data.view(-1).split(tokens_per_sample)``
data = TokenBlockDataset(
data,
data.sizes,
block_size=self.cfg.tokens_per_sample,
pad=None, # unused
eos=None, # unused
break_mode="none",
)
self.datasets[split] = TruncatedBPTTDataset(
data=data,
bsz_per_shard=self.cfg.batch_size,
shard_id=self.cfg.data_parallel_rank,
num_shards=self.cfg.data_parallel_size,
)
def dataset(self, split):
return self.datasets[split]
def get_batch_iterator(
self,
dataset,
num_workers=0,
epoch=1,
data_buffer_size=0,
skip_remainder_batch=False,
**kwargs
):
return iterators.EpochBatchIterator(
dataset=dataset,
collate_fn=self._collate_fn,
num_workers=num_workers,
epoch=epoch,
buffer_size=data_buffer_size,
# we don't use the batching functionality from EpochBatchIterator;
# instead every item in *dataset* is a whole batch
batch_sampler=[[i] for i in range(len(dataset))],
disable_shuffling=True,
skip_remainder_batch=skip_remainder_batch,
)
def _collate_fn(self, items: List[List[torch.Tensor]]):
        # we don't use fairseq's batching functionality, so *items* contains a
        # single element of the form (id, List[torch.Tensor]) that is already a
        # whole batch
assert len(items) == 1
# item will have shape B x T (the last batch may have length < T)
id, item = items[0]
item = data_utils.collate_tokens(item, pad_idx=self.source_dictionary.pad())
B, T = item.size()
# shift item one position over and append a padding token for the target
target = torch.nn.functional.pad(
item[:, 1:], (0, 1, 0, 0), value=self.target_dictionary.pad()
)
# fairseq expects batches to have the following structure
return {
"id": torch.tensor([id] * item.size(0)),
"net_input": {"src_tokens": item,},
"target": target,
"nsentences": item.size(0),
"ntokens": item.numel(),
}
def build_dataset_for_inference(
self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
) -> torch.utils.data.Dataset:
eos = self.source_dictionary.eos()
dataset = TokenBlockDataset(
src_tokens,
src_lengths,
block_size=None, # ignored for "eos" break mode
pad=self.source_dictionary.pad(),
eos=eos,
break_mode="eos",
)
class Dataset(torch.utils.data.Dataset):
def __getitem__(self, i):
item = dataset[i]
if item[-1] == eos:
# remove eos to support generating with a prefix
item = item[:-1]
return (i, [item])
def __len__(self):
return len(dataset)
return Dataset()
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
with torch.no_grad():
if constraints is not None:
raise NotImplementedError
# SequenceGenerator doesn't use *src_tokens* directly, we need to
# pass the *prefix_tokens* argument instead.
if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
prefix_tokens = sample["net_input"]["src_tokens"]
# begin generation with the end-of-sentence token
bos_token = self.source_dictionary.eos()
return generator.generate(
models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
)
def eval_lm_dataloader(
self,
dataset,
max_tokens: Optional[int] = 36000,
batch_size: Optional[int] = None,
max_positions: Optional[int] = None,
num_shards: int = 1,
shard_id: int = 0,
num_workers: int = 1,
data_buffer_size: int = 10,
context_window: int = 0,
):
if context_window > 0:
raise NotImplementedError(
"Transformer-XL doesn't need --context-window, try "
"--model-overrides '{\"mem_len\":42}' instead "
)
return self.get_batch_iterator(
dataset=dataset,
max_tokens=max_tokens,
max_sentences=batch_size,
max_positions=max_positions,
ignore_invalid_inputs=True,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
data_buffer_size=data_buffer_size,
).next_epoch_itr(shuffle=False)
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary
class TruncatedBPTTDataset(torch.utils.data.Dataset):
def __init__(
self,
data: List[torch.Tensor], # ordered list of items
        bsz_per_shard, # number of items processed per GPU per forward pass
shard_id, # current GPU ID
num_shards, # number of GPUs
):
super().__init__()
self.data = data
def batchify(data, bsz):
# Work out how cleanly we can divide the dataset into bsz parts.
nbatch = data.size(0) // bsz
# Trim off any extra elements that wouldn't cleanly fit (remainders).
data = data.narrow(0, 0, nbatch * bsz)
# Evenly divide the data across the bsz batches.
data = data.view(bsz, -1).contiguous()
return data
# total number of sequences processed by all GPUs in each forward pass
global_batch_size = bsz_per_shard * num_shards
"""
With a 16 item dataset, bsz_per_shard=2 and num_shards=3,
*indices* might look like:
indices = [[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9],
[10, 11]]
The size of the TruncatedBPTTDataset instance will be 2,
and shard 1 will see items:
[(0, [data[4], data[6]]),
(1, [data[5], data[7]])]
"""
indices = batchify(torch.arange(len(data)), global_batch_size)
assert indices.size(0) == global_batch_size
self.my_indices = indices[
shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard
]
assert self.my_indices.size(0) == bsz_per_shard
def __len__(self):
return self.my_indices.size(1)
def __getitem__(self, i) -> Tuple[int, List[torch.Tensor]]:
return (i, [self.data[idx] for idx in self.my_indices[:, i]])
# Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling (Gong et al., 2021)
[https://arxiv.org/pdf/2106.10840.pdf](https://arxiv.org/pdf/2106.10840.pdf)
## Introduction
We present attention head selection strategies for multilingual and multi-domain sequence modeling, covering text translation, speech recognition, and speech translation tasks.
Below is an example of training multilingual/multi-domain speech recognition models.
## Data Preparation
Prepare mTEDx data as in [mTEDx example](https://github.com/fairinternal/fairseq-py/blob/0d9c5851e6fac40f9e366b3633ccd615c2901788/examples/speech_to_text/docs/mtedx_example.md) and CoVoST data as in [CoVoST example](https://github.com/fairinternal/fairseq-py/blob/0d9c5851e6fac40f9e366b3633ccd615c2901788/examples/speech_to_text/docs/covost_example.md). Similarly prepare EuroParl data.
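Note that each split name encodes the source language, target language and domain as `<prefix>_<src>_<tgt>_<domain>` (e.g., `train_es_es_tedx`); the head selection task parses these fields to assign language and domain IDs.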
## Training a multilingual ASR model with attention head selection
```bash
data_dir=<path to mtedx data>
train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx"
valid_subset="valid_ar_ar_tedx,valid_de_de_tedx,valid_el_el_tedx,valid_es_es_tedx,valid_fr_fr_tedx,valid_it_it_tedx,valid_pt_pt_tedx,valid_ru_ru_tedx"
strategy=<subset or group>
fairseq-train ${data_dir} \
--user-dir examples/attention_head_selection/src \
--train-subset "${train_subset}" \
--valid-subset "${valid_subset}" \
--config-yaml 'config_asr.yaml' \
--arch 'head_selection_s2t_transformer_s' \
--task 'speech_to_text_head_selection' \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \
--lr 5e-4 \
--clip-norm 10.0 \
--seed 1 \
--max-epoch 400 \
--max-tokens 32000 \
--ignore-prefix-size 1 \
--dropout 0.3 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--skip-invalid-size-inputs-valid-test \
--encoder-attn-head-select \
--total-encoder-attention-heads 8 \
--decoder-self-attn-head-select \
--total-decoder-attention-heads 8 \
--attn-head-select-strategy ${strategy} \
--task-type lang
```
## Training a multi-domain ASR model with attention head selection
```bash
data_dir=<path to multi-domain data>
train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep"
valid_subset="dev_es_es_tedx,dev_fr_fr_tedx,dev_pt_pt_tedx,dev_it_it_tedx,dev_ru_ru_tedx,dev_el_el_tedx,dev_ar_ar_tedx,dev_de_de_tedx,dev_ar_ar_cv,dev_de_de_cv,dev_es_es_cv,dev_fr_fr_cv,dev_it_it_cv,dev_pt_pt_cv,dev_ru_ru_cv,dev_de_de_ep,dev_es_es_ep,dev_fr_fr_ep,dev_it_it_ep,dev_pt_pt_ep"
strategy=<subset or group>
fairseq-train ${data_dir} \
--user-dir examples/attention_head_selection/src \
--train-subset "${train_subset}" \
--valid-subset "${valid_subset}" \
--config-yaml 'config_asr.yaml' \
--arch head_selection_s2t_transformer_s \
--task speech_to_text_head_selection \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \
--lr 5e-4 \
--clip-norm 10.0 \
--seed 1 \
--max-epoch 400 \
--max-tokens 32000 \
--ignore-prefix-size 1 \
--dropout 0.3 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--skip-invalid-size-inputs-valid-test \
--encoder-attn-head-select \
--total-encoder-attention-heads 8 \
--decoder-self-attn-head-select \
--total-decoder-attention-heads 8 \
--attn-head-select-strategy ${strategy} \
--task-type domain
```
## Inference in multilingual setting
```bash
MODEL_DIR=<checkpoint directory>
data_dir=<path to mtedx data>
gen_subset=<data to test, e.g., test_ar_ar_tedx>
train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx"
last_n=10
CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt"
CHECKPOINT="_avg"
RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}"
if [ ! -d $RESULTS ]; then
mkdir -p $RESULTS
fi;
python scripts/average_checkpoints.py \
--inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \
--output "${MODEL_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${data_dir} \
--user-dir examples/attention_head_selection/src \
--arch 'head_selection_s2t_transformer_s' \
--task 'speech_to_text_head_selection' \
--train-subset ${train_subset} \
--gen-subset ${gen_subset} \
--path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \
--config-yaml 'config_asr.yaml' \
--prefix-size 1 \
--max-tokens 40000 --beam 5 \
--skip-invalid-size-inputs-valid-test \
--results-path ${RESULTS} \
--scoring wer --wer-tokenizer 13a \
--wer-lowercase --wer-remove-punct --remove-bpe
```
## Inference in multi-domain setting
```bash
MODEL_DIR=<checkpoint directory>
data_dir=<path to multi-domain data>
gen_subset=<data to test, e.g., test_pt_pt_cv>
train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep"
last_n=10
CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt"
CHECKPOINT="_avg"
RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}"
if [ ! -d $RESULTS ]; then
mkdir -p $RESULTS
fi;
python scripts/average_checkpoints.py \
--inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \
--output "${MODEL_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${data_dir} \
--user-dir examples/attention_head_selection/src \
--arch 'head_selection_s2t_transformer_s' \
--task 'speech_to_text_head_selection' \
--train-subset ${train_subset} \
--gen-subset ${gen_subset} \
--path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \
--config-yaml 'config_asr.yaml' \
--prefix-size 1 \
--max-tokens 40000 --beam 5 \
--skip-invalid-size-inputs-valid-test \
--results-path ${RESULTS} \
--scoring wer --wer-tokenizer 13a \
--wer-lowercase --wer-remove-punct --remove-bpe
```
## Citation
```bibtex
@article{gong2021pay,
title={Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling},
author={Gong, Hongyu and Tang, Yun and Pino, Juan and Li, Xian},
journal={arXiv preprint arXiv:2106.10840},
year={2021}
}
```
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from pathlib import Path
from typing import Dict, List, Optional
from dataclasses import dataclass
import torch
from fairseq.data import (
ConcatDataset,
Dictionary,
FairseqDataset,
ResamplingDataset
)
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDatasetItem,
SpeechToTextDataset,
SpeechToTextDatasetCreator
)
logger = logging.getLogger(__name__)
@dataclass
class SpeechToTextDatasetItemWithDomain(SpeechToTextDatasetItem):
src_lang_id: Optional[torch.Tensor] = None
tgt_lang_id: Optional[torch.Tensor] = None
domain_id: Optional[torch.Tensor] = None
class SpeechToTextDatasetWithDomain(SpeechToTextDataset):
def __init__(
self,
split: str,
is_train_split: bool,
cfg: S2TDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None,
src_lang_ids: Optional[List[int]] = None,
tgt_lang_ids: Optional[List[int]] = None,
domain_ids: Optional[List[int]] = None
):
super().__init__(
split, is_train_split, cfg, audio_paths, n_frames,
src_texts, tgt_texts, speakers, src_langs, tgt_langs,
ids, tgt_dict, pre_tokenizer, bpe_tokenizer,
n_frames_per_step, speaker_to_id
)
assert src_lang_ids is None or len(src_lang_ids) == self.n_samples
assert tgt_lang_ids is None or len(tgt_lang_ids) == self.n_samples
assert domain_ids is None or len(domain_ids) == self.n_samples
self.src_lang_ids = src_lang_ids
self.tgt_lang_ids = tgt_lang_ids
self.domain_ids = domain_ids
def __getitem__(self, index: int) -> SpeechToTextDatasetItemWithDomain:
item = super().__getitem__(index)
src_lang_id = self.src_lang_ids[index]
tgt_lang_id = self.tgt_lang_ids[index]
domain_id = self.domain_ids[index]
return SpeechToTextDatasetItemWithDomain(
index=item.index, source=item.source,
target=item.target, speaker_id=item.speaker_id,
src_lang_id=src_lang_id,
tgt_lang_id=tgt_lang_id,
domain_id=domain_id
)
def collater(
        self, samples: List[SpeechToTextDatasetItemWithDomain], return_order: bool = False
) -> Dict:
if len(samples) == 0:
return {}
out = super().collater(samples, return_order=True)
order = out["order"]
src_lang_ids = torch.tensor([x.src_lang_id for x in samples], dtype=torch.long).index_select(0, order)
tgt_lang_ids = torch.tensor([x.tgt_lang_id for x in samples], dtype=torch.long).index_select(0, order)
domain_ids = torch.tensor([x.domain_id for x in samples], dtype=torch.long).index_select(0, order)
out["src_lang_ids"] = src_lang_ids
out["tgt_lang_ids"] = tgt_lang_ids
out["domain_ids"] = domain_ids
if not return_order:
del out["order"]
return out
class SpeechToTextDatasetCreatorWithDomain(SpeechToTextDatasetCreator):
KEY_SRC_LANG_ID, KEY_TGT_LANG_ID = "src_lang_id", "tgt_lang_id"
KEY_DOMAIN_ID = "domain_id"
# default values
DEFAULT_SRC_LANG_ID, DEFAULT_TGT_LANG_ID, DEFAULT_DOMAIN_ID = 0, 0, 0
@classmethod
def _from_list(
cls,
split_name: str,
is_train_split,
samples: List[Dict],
cfg: S2TDataConfig,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id
) -> SpeechToTextDatasetWithDomain:
audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
src_lang_ids = [s.get(cls.KEY_SRC_LANG_ID, cls.DEFAULT_SRC_LANG_ID) for s in samples]
tgt_lang_ids = [s.get(cls.KEY_TGT_LANG_ID, cls.DEFAULT_TGT_LANG_ID) for s in samples]
domain_ids = [s.get(cls.KEY_DOMAIN_ID, cls.DEFAULT_DOMAIN_ID) for s in samples]
return SpeechToTextDatasetWithDomain(
split_name,
is_train_split,
cfg,
audio_paths,
n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
src_lang_ids=src_lang_ids,
tgt_lang_ids=tgt_lang_ids,
domain_ids=domain_ids
)
@classmethod
def _load_samples_from_tsv(
cls,
root: str,
split: str,
src_lang_map,
tgt_lang_map,
domain_map
):
# metadata from split
_, src_lang, tgt_lang, domain = split.split("_")
src_lang_id = src_lang_map[src_lang]
tgt_lang_id = tgt_lang_map[tgt_lang]
domain_id = domain_map[domain]
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
for s in samples:
s.update({
cls.KEY_SRC_LANG_ID: src_lang_id,
cls.KEY_TGT_LANG_ID: tgt_lang_id,
cls.KEY_DOMAIN_ID: domain_id
})
return samples
@classmethod
def _from_tsv(
cls,
root: str,
cfg: S2TDataConfig,
split: str,
tgt_dict,
is_train_split: bool,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
src_lang_map: Dict[str, int],
tgt_lang_map: Dict[str, int],
domain_map: Dict[str, int]
    ) -> SpeechToTextDatasetWithDomain:
samples = cls._load_samples_from_tsv(
root, split, src_lang_map,
tgt_lang_map, domain_map
)
return cls._from_list(
split, is_train_split, samples, cfg, tgt_dict, pre_tokenizer,
bpe_tokenizer, n_frames_per_step, speaker_to_id
)
@classmethod
def from_tsv(
cls,
root: str,
cfg: S2TDataConfig,
splits: str,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
is_train_split: bool,
epoch: int,
seed: int,
src_lang_map: Dict[str, int],
tgt_lang_map: Dict[str, int],
domain_map: Dict[str, int],
n_frames_per_step: int = 1,
speaker_to_id=None
) -> SpeechToTextDatasetWithDomain:
datasets = [
cls._from_tsv(
root, cfg, split, tgt_dict, is_train_split, pre_tokenizer, bpe_tokenizer, n_frames_per_step, speaker_to_id, src_lang_map, tgt_lang_map, domain_map
)
for split in splits.split(",")
]
if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
# temperature-based sampling
size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
datasets = [
ResamplingDataset(
d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
)
for r, d in zip(size_ratios, datasets)
]
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch.nn.modules.loss import _Loss
class HeadSelectionLoss(_Loss):
def __init__(self, args):
super().__init__()
self.args = args
self.kl_weight = getattr(args, "kl_weight", 0.0)
def forward(self, head_samples, sample_sizes, prior=0.5, eps=1e-7):
"""
        head_samples: (num_tasks, num_layers, num_heads)
sample_sizes: (num_tasks, )
"""
kl_loss = (head_samples * (torch.log(head_samples + eps) - math.log(prior))).sum(-1).sum(-1)
kl_loss /= (torch.numel(head_samples) / head_samples.size(0))
kl_loss = self.kl_weight * torch.matmul(kl_loss, sample_sizes)
return kl_loss
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict, List, Optional
from pathlib import Path
import torch.nn as nn
from torch import Tensor
from fairseq import checkpoint_utils
from fairseq.models import register_model, register_model_architecture
from fairseq.utils import safe_hasattr
from fairseq.models.speech_to_text.s2t_transformer import (
S2TTransformerModel,
S2TTransformerEncoder,
TransformerDecoderScriptable
)
from fairseq.models.speech_to_text.s2t_transformer import base_architecture as s2t_base_architecture
from ..modules.attn_head_selector import AttnHeadSelector
from ..modules.head_selection_transformer_layer import HeadSelectionTransformerEncoderLayer
from .head_selection_transformer import HeadSelectionTransformerDecoder
logger = logging.getLogger(__name__)
@register_model("head_selection_s2t_transformer")
class HeadSelectionS2TTransformerModel(S2TTransformerModel):
"""
Head selection implemented in S2TTransformer
"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
S2TTransformerModel.add_args(parser)
# encoder head selection
parser.add_argument(
"--encoder-attn-head-select",
action="store_true",
default=False,
help="encoder head selection"
)
parser.add_argument(
"--total-encoder-attention-heads",
type=int,
help="total number of encoder attention heads"
)
# decoder self attention selection
parser.add_argument(
"--decoder-self-attn-head-select",
action="store_true",
default=False,
help="decoder self-attention head selection"
)
# decoder-encoder attention selection
parser.add_argument(
"--dec-enc-attn-head-select",
action="store_true",
default=False,
help="decoder-encoder attention head selection"
)
parser.add_argument(
"--total-decoder-attention-heads",
type=int,
help="total number of decoder attention heads"
)
# selection strategy
parser.add_argument(
"--attn-head-select-strategy",
type=str,
help="attention head selection strategy, subset or group"
)
@classmethod
def build_encoder(cls, args):
if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select:
encoder = HeadSelectionS2TTransformerEncoder(args)
else:
encoder = S2TTransformerEncoder(args)
pretraining_path = getattr(args, "load_pretrained_encoder_from", None)
if pretraining_path is not None:
if not Path(pretraining_path).exists():
logger.warning(
f"skipped pretraining because {pretraining_path} does not exist"
)
else:
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=pretraining_path
)
logger.info(f"loaded pretrained encoder from: {pretraining_path}")
return encoder
@classmethod
def build_decoder(cls, args, task, embed_tokens):
if (safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select) or (safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select):
return HeadSelectionTransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
else:
return TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
class HeadSelectionS2TTransformerEncoder(S2TTransformerEncoder):
def __init__(self, args):
super().__init__(args)
self.attn_head_selector = AttnHeadSelector(
args.encoder_tasks,
args.encoder_layers,
args.total_encoder_attention_heads,
args.encoder_attention_heads,
args.attn_head_select_strategy,
)
self.task_ids = None
self.transformer_layers = nn.ModuleList([
HeadSelectionTransformerEncoderLayer(args, layer_idx, attn_head_selector=self.attn_head_selector) for layer_idx in range(args.encoder_layers)
])
def set_task_ids(self, task_ids):
self.task_ids = task_ids
def _forward(self, src_tokens, src_lengths, return_all_hiddens=False):
self.attn_head_selector.head_select(self.task_ids)
return super()._forward(src_tokens, src_lengths, return_all_hiddens)
class HeadSelectionTransformerDecoderScriptable(HeadSelectionTransformerDecoder):
def extract_features(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
# call scriptable method from parent class
x, _ = self.extract_features_scriptable(
prev_output_tokens,
encoder_out,
incremental_state,
full_context_alignment,
alignment_layer,
alignment_heads,
)
return x, None
@register_model_architecture(model_name="head_selection_s2t_transformer", arch_name="head_selection_s2t_transformer")
def base_architecture(args):
s2t_base_architecture(args)
args.encoder_attn_head_select = getattr(args, "encoder_attn_head_select", False)
args.decoder_self_attn_head_select = getattr(args, "decoder_self_attn_head_select", False)
args.dec_enc_attn_head_select = getattr(args, "dec_enc_attn_head_select", False)
args.total_encoder_attention_heads = getattr(args, "total_encoder_attention_heads", 8)
args.total_decoder_attention_heads = getattr(args, "total_decoder_attention_heads", 8)
args.attn_head_select_strategy = getattr(args, "attn_head_select_strategy", "group")
@register_model_architecture("head_selection_s2t_transformer", "head_selection_s2t_transformer_s")
def head_selection_s2t_transformer_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
base_architecture(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, List, Dict, Optional
import torch
import torch.nn as nn
from torch import Tensor
from fairseq.utils import safe_hasattr
from fairseq.models.transformer import (
TransformerModel,
TransformerEncoder,
TransformerDecoder
)
from ..modules.attn_head_selector import AttnHeadSelector
from ..modules.head_selection_transformer_layer import (
HeadSelectionTransformerEncoderLayer,
HeadSelectionTransformerDecoderLayer
)
class HeadSelectionTransformerModel(TransformerModel):
def __init__(self, args, encoder, decoder):
super().__init__(args, encoder, decoder)
@staticmethod
def add_args(parser):
TransformerModel.add_args(parser)
# encoder head selection
parser.add_argument(
"--encoder-attn-head-select",
action="store_true",
default=False,
help="encoder head selection"
)
parser.add_argument(
"--total-encoder-attention-heads",
type=int,
help="total number of encoder attention heads"
)
# decoder self attention
parser.add_argument(
"--decoder-self-attn-head-select",
action="store_true",
default=False,
help="decoder self-attention head selection"
)
# decoder-encoder attention
parser.add_argument(
"--dec-enc-attn-head-select",
action="store_true",
default=False,
help="decoder-encoder attention head selection"
)
parser.add_argument(
"--total-decoder-attention-heads",
type=int,
help="total number of decoder attention heads"
)
# selection strategy
parser.add_argument(
"--attn-head-select-strategy",
type=str,
help="attention head selection strategy, subset or group"
)
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select:
return HeadSelectionTransformerEncoder(
args, src_dict, embed_tokens
)
else:
return TransformerEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
if (safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select) or (safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select):
return HeadSelectionTransformerDecoder(
args, tgt_dict, embed_tokens
)
else:
return TransformerDecoder(args, tgt_dict, embed_tokens)
class HeadSelectionTransformerEncoder(TransformerEncoder):
def __init__(self, args, dictionary, embed_tokens):
self.num_tasks = args.encoder_tasks
self.num_layers = args.encoder_layers
self.total_num_heads = args.total_encoder_attention_heads
self.num_heads = args.encoder_attention_heads
self.select_strategy = args.attn_head_select_strategy
super().__init__(args, dictionary, embed_tokens)
self.attn_head_selector = AttnHeadSelector(
self.num_tasks,
self.num_layers,
self.total_num_heads,
self.num_heads,
self.select_strategy
)
self.task_ids = None
self.layers = nn.ModuleList(
[self.build_encoder_layer(args, i) for i in range(args.encoder_layers)]
)
def set_task_ids(self, task_ids):
self.task_ids = task_ids
def build_encoder_layer(self, args, layer_idx=None):
return HeadSelectionTransformerEncoderLayer(
args,
layer_idx,
attn_head_selector=self.attn_head_selector
)
def forward(
self,
src_tokens,
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
):
self.attn_head_selector.head_select(self.task_ids)
return super().forward(src_tokens, src_lengths, return_all_hiddens, token_embeddings)
class HeadSelectionTransformerDecoder(TransformerDecoder):
def __init__(
self,
args,
dictionary,
embed_tokens,
no_encoder_attn=False,
output_projection=None,
):
self.num_tasks = args.decoder_tasks
self.num_layers = args.decoder_layers
self.total_num_heads = args.total_decoder_attention_heads
self.num_heads = args.decoder_attention_heads
self.select_strategy = args.attn_head_select_strategy
super().__init__(
args, dictionary, embed_tokens,
no_encoder_attn=no_encoder_attn,
output_projection=output_projection
)
self.self_attn_head_selector = None
self.enc_attn_head_selector = None
if safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select:
self.self_attn_head_selector = AttnHeadSelector(
self.num_tasks,
self.num_layers,
self.total_num_heads,
self.num_heads,
self.select_strategy
)
if safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select:
self.enc_attn_head_selector = AttnHeadSelector(
self.num_tasks,
self.num_layers,
self.total_num_heads,
self.num_heads,
self.select_strategy
)
self.task_ids = None
self.layers = nn.ModuleList(
[
self.build_head_selection_decoder_layer(args, no_encoder_attn, idx) for idx in range(args.decoder_layers)
]
)
def set_task_ids(self, task_ids):
self.task_ids = task_ids
def build_head_selection_decoder_layer(self, args, no_encoder_attn=False, layer_idx=None):
return HeadSelectionTransformerDecoderLayer(
args,
layer_idx,
self.self_attn_head_selector,
self.enc_attn_head_selector,
no_encoder_attn=no_encoder_attn
)
def forward(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
):
if self.self_attn_head_selector is not None:
self.self_attn_head_selector.head_select(self.task_ids)
if self.enc_attn_head_selector is not None:
self.enc_attn_head_selector.head_select(self.task_ids)
return super().forward(
prev_output_tokens=prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
features_only=features_only,
full_context_alignment=full_context_alignment,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
src_lengths=src_lengths,
return_all_hiddens=return_all_hiddens
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import math
class AttnHeadSelector(nn.Module):
"""
Latent variable modeling of attention head selection
"""
def __init__(
self, num_tasks, num_layers,
total_num_heads, num_heads,
select_strategy="group",
head_select_temp=5.0
):
super(AttnHeadSelector, self).__init__()
self.num_tasks = num_tasks
self.num_layers = num_layers
self.total_num_heads = total_num_heads
self.num_heads = num_heads
self.select_strategy = select_strategy
self.temp = head_select_temp
self.head_logits = torch.nn.Parameter(
torch.Tensor(self.num_tasks, self.num_layers, total_num_heads),
requires_grad=True
)
nn.init.uniform_(
self.head_logits, a=math.log(0.01),
b=math.log(1.0)
)
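    # Relaxed binary (Gumbel-sigmoid) sampling: adding the difference of two
    # Gumbel noise terms to the logits and applying a sigmoid yields a soft,
    # differentiable per-head selection score.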
def gumbel_sample(self, logits, tau=1.0):
gumbels1 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
gumbels2 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
gumbels1 = (logits + gumbels1 - gumbels2) / tau
y_soft = gumbels1.sigmoid()
return y_soft
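    # Both selection strategies return hard top-k head indices together with a
    # straight-through weight: `1.0 - v.detach() + v` is exactly 1.0 in the
    # forward pass but carries the gradient of the soft score v backward.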
def subset_select(self, y_soft, topk, dim=-1):
top_values, top_inds = torch.topk(y_soft, k=topk, dim=dim)
top_ret = 1.0 - top_values.detach() + top_values
return top_inds.detach(), top_ret
    def group_select(self, y_soft, topk, dim=-1):
# top_values: (num_tasks, num_layers, topk)
top_values, top_inds = torch.max(
y_soft.view(self.num_tasks, self.num_layers, -1, topk), dim=2
)
top_inds = top_inds * topk + torch.arange(topk, device=top_inds.device).unsqueeze(0).unsqueeze(1)
top_ret = 1.0 - top_values.detach() + top_values
return top_inds.detach(), top_ret
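    # Sample head scores once per forward pass and cache the per-task selection;
    # individual layers later fetch their slice through forward(layer_idx).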
def head_select(self, task_ids=None):
# gumbel_sample
self.head_samples = self.gumbel_sample(self.head_logits, tau=self.temp)
# head select
if self.select_strategy == "subset":
self.subset_heads, self.subset_weights = self.subset_select(
self.head_samples,
topk=self.num_heads,
)
elif self.select_strategy == "group":
            self.subset_heads, self.subset_weights = self.group_select(
self.head_samples,
topk=self.num_heads,
)
else:
raise ValueError("{} is not supported".format(self.select_strategy))
self.batch_subset = self.subset_heads[task_ids, :, :]
self.batch_weights = self.subset_weights[task_ids, :, :]
def forward(self, layer_idx):
assert layer_idx is not None
batch_subset = self.batch_subset[:, layer_idx, :]
batch_weights = self.batch_weights[:, layer_idx, :]
return batch_subset, batch_weights
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.utils import safe_getattr
from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
from ..modules.multihead_attention_selection import MultiheadAttentionSelection
class HeadSelectionTransformerEncoderLayer(TransformerEncoderLayer):
def __init__(self, args, layer_idx, attn_head_selector=None):
super().__init__(args)
self.layer_idx = layer_idx
self.self_attn = self.build_self_attention_selection(
self.embed_dim, args, attn_head_selector
)
def build_self_attention_selection(self, embed_dim, args, attn_head_selector=None):
return MultiheadAttentionSelection(
embed_dim,
args.total_encoder_attention_heads,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
layer_idx=self.layer_idx,
attn_head_selector=attn_head_selector
)
class HeadSelectionTransformerDecoderLayer(TransformerDecoderLayer):
def __init__(
self,
args,
layer_idx,
self_attn_head_selector=None,
enc_attn_head_selector=None,
no_encoder_attn=False,
add_bias_kv=False,
add_zero_attn=False,
):
self.layer_idx = layer_idx
super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
if self_attn_head_selector is not None:
self.self_attn = self.build_self_attention_selection(
self.embed_dim, args,
self_attn_head_selector=self_attn_head_selector,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn
)
if enc_attn_head_selector is not None:
self.encoder_attn = self.build_encoder_attention_selection(
self.embed_dim, args,
enc_attn_head_selector=enc_attn_head_selector
)
def build_self_attention_selection(
self, embed_dim, args, self_attn_head_selector=None,
add_bias_kv=False, add_zero_attn=False
):
return MultiheadAttentionSelection(
embed_dim,
args.total_decoder_attention_heads,
args.decoder_attention_heads,
dropout=args.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=not safe_getattr(args, "cross_self_attention"),
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
layer_idx=self.layer_idx,
attn_head_selector=self_attn_head_selector,
)
def build_encoder_attention_selection(self, embed_dim, args, enc_attn_head_selector=None):
return MultiheadAttentionSelection(
embed_dim,
args.total_decoder_attention_heads,
args.decoder_attention_heads,
kdim=args.encoder_embed_dim,
vdim=args.encoder_embed_dim,
dropout=args.attention_dropout,
encoder_decoder_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
layer_idx=self.layer_idx,
attn_head_selector=enc_attn_head_selector,
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional, Tuple
import torch
from fairseq import utils
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor, nn
from torch.nn import Parameter
from fairseq.modules.multihead_attention import MultiheadAttention
from ..modules.multihead_functional import multi_head_attention_forward
class MultiheadAttentionSelection(MultiheadAttention):
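    """
    Multi-head attention that projects queries, keys and values for
    *total_num_heads* candidate heads; at forward time the heads picked by
    *attn_head_selector* (a per-task subset of size *num_heads*) are kept and
    re-weighted before the output projection.
    """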
def __init__(
self,
embed_dim,
total_num_heads,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
layer_idx=0,
attn_head_selector=None
):
super().__init__(
embed_dim,
num_heads,
kdim=kdim,
vdim=vdim,
dropout=dropout,
bias=bias,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=self_attention,
encoder_decoder_attention=encoder_decoder_attention,
q_noise=q_noise,
qn_block_size=qn_block_size,
)
self.layer_idx = layer_idx
self.attn_head_selector = attn_head_selector
self.total_num_heads = total_num_heads
self.total_embed_dim = self.head_dim * total_num_heads
self.k_proj = quant_noise(
nn.Linear(self.kdim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
)
self.v_proj = quant_noise(
nn.Linear(self.vdim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
)
self.q_proj = quant_noise(
nn.Linear(embed_dim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
)
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, self.total_embed_dim))
self.bias_v = Parameter(torch.Tensor(1, 1, self.total_embed_dim))
else:
self.bias_k = self.bias_v = None
self.reset_parameters()
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
# subset_heads: Optional[Tensor] = None,
# subset_weights: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]:
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
subset_heads, subset_weights = self.attn_head_selector(self.layer_idx)
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert list(query.size()) == [tgt_len, bsz, self.embed_dim]
if key is not None:
src_len, key_bsz, _ = key.size()
if not torch.jit.is_scripting():
assert key_bsz == bsz
assert value is not None
                assert value.shape[:2] == (src_len, bsz)
if (
not self.onnx_trace
and not is_tpu # don't use PyTorch version on TPUs
and incremental_state is None
and not static_kv
# A workaround for quantization to work. Otherwise JIT compilation
# treats bias in linear module as method.
and not torch.jit.is_scripting()
):
assert key is not None and value is not None
return multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.total_num_heads,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
key_padding_mask,
need_weights,
attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
subset_heads=subset_heads,
subset_weights=subset_weights
)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = (
q.contiguous()
.view(tgt_len, bsz * self.total_num_heads, self.head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.total_num_heads, self.head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.total_num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.total_num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
src_len = k.size(1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.total_num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.total_num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.total_num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
assert k.size(1) == src_len
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [bsz * self.total_num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
if self.onnx_trace:
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.total_num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.total_num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v
attn_weights_float = utils.softmax(
attn_weights, dim=-1, onnx_trace=self.onnx_trace
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
# evaluation
if subset_heads is not None and subset_heads.numel() == 1:
subset_heads = subset_heads.repeat(bsz)
subset_weights = subset_weights.repeat(bsz)
if subset_heads is None:
attn = torch.bmm(attn_probs, v)
else:
# training with head selection
mixed_attn = torch.bmm(attn_probs, v).contiguous().view(bsz, self.total_num_heads, tgt_len, self.head_dim)
attn = torch.stack(
[mixed_attn[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1
)
attn = attn * subset_weights.unsqueeze(2).unsqueeze(3)
attn = attn.contiguous().view(bsz * self.num_heads, tgt_len, self.head_dim)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if self.onnx_trace and attn.size(1) == 1:
# when ONNX tracing a single decoder step (sequence length == 1)
# the transpose is a no-op copy before view, thus unnecessary
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
else:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
if subset_heads is None:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
else:
mixed_attn_weights = attn_weights_float.view(
bsz, self.total_num_heads, tgt_len, src_len
)
attn_weights = torch.stack(
[mixed_attn_weights[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch.nn.functional import (
linear, softmax, dropout, pad,
has_torch_function,
handle_torch_function,
_in_projection_packed,
)
import math
import warnings
def _scaled_dot_product_attention(
q: Tensor,
k: Tensor,
v: Tensor,
attn_mask: Optional[Tensor] = None,
dropout_p: float = 0.0,
bsz: int = 1,
subset_heads: Optional[Tensor] = None,
subset_weights: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
B, Nt, E = q.shape
q = q / math.sqrt(E)
# B: bsz * total_num_heads
# (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
attn = torch.bmm(q, k.transpose(-2, -1))
if attn_mask is not None:
attn += attn_mask
attn = softmax(attn, dim=-1)
if dropout_p > 0.0:
attn = dropout(attn, p=dropout_p)
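    # With head selection, gather the chosen heads for every batch element and
    # re-weight them before flattening back to (bsz * num_heads, Nt, E).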
if subset_heads is None:
# (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
output = torch.bmm(attn, v)
else:
mixed_output = torch.bmm(attn, v).contiguous().view(bsz, -1, Nt, E)
output = torch.stack(
[mixed_output[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))],
dim=1
)
output = output * subset_weights.unsqueeze(2).unsqueeze(3)
output = output.contiguous().view(-1, Nt, E)
if subset_heads is not None:
_, Nt, Ns = attn.size()
mixed_attn = attn.view(bsz, -1, Nt, Ns)
attn = torch.stack(
[mixed_attn[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1
)
return output, attn
def _in_projection(
q: Tensor,
k: Tensor,
v: Tensor,
w_q: Tensor,
w_k: Tensor,
w_v: Tensor,
b_q: Optional[Tensor] = None,
b_k: Optional[Tensor] = None,
b_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
def multi_head_attention_forward(
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
total_num_heads: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Optional[Tensor],
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Optional[Tensor],
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
subset_heads: Optional[Tensor] = None,
subset_weights: Optional[Tensor] = None,
):
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
if has_torch_function(tens_ops):
return handle_torch_function(
multi_head_attention_forward,
tens_ops,
query,
key,
value,
embed_dim_to_check,
total_num_heads,
num_heads,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
add_zero_attn,
dropout_p,
out_proj_weight,
out_proj_bias,
training=training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=use_separate_proj_weight,
q_proj_weight=q_proj_weight,
k_proj_weight=k_proj_weight,
v_proj_weight=v_proj_weight,
static_k=static_k,
static_v=static_v,
subset_heads=subset_heads,
subset_weights=subset_weights
)
# set up shape vars
tgt_len, bsz, embed_dim = query.shape
src_len, _, _ = key.shape
assert embed_dim == embed_dim_to_check, \
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
else:
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert key.shape[:2] == value.shape[:2], \
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
#
# compute in-projection
#
if not use_separate_proj_weight:
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else:
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = in_proj_bias.chunk(3)
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
# prep attention mask
if attn_mask is not None:
if attn_mask.dtype == torch.uint8:
warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
attn_mask = attn_mask.to(torch.bool)
else:
assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
# ensure attn_mask's dim is 3
if attn_mask.dim() == 2:
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * total_num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
# prep key padding mask
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
key_padding_mask = key_padding_mask.to(torch.bool)
# add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
else:
assert bias_k is None
assert bias_v is None
#
    # reshape q, k, v for multi-head attention and make them batch-first
#
q = q.contiguous().view(tgt_len, bsz * total_num_heads, head_dim).transpose(0, 1)
if static_k is None:
k = k.contiguous().view(k.shape[0], bsz * total_num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.size(0) == bsz * total_num_heads, \
f"expecting static_k.size(0) of {bsz * total_num_heads}, but got {static_k.size(0)}"
assert static_k.size(2) == head_dim, \
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k
if static_v is None:
v = v.contiguous().view(v.shape[0], bsz * total_num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_v.size(0) == bsz * total_num_heads, \
f"expecting static_v.size(0) of {bsz * total_num_heads}, but got {static_v.size(0)}"
assert static_v.size(2) == head_dim, \
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * total_num_heads, 1, head_dim)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
# update source sequence length after adjustments
src_len = k.size(1)
# merge key padding and attention masks
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
expand(-1, total_num_heads, -1, -1).reshape(bsz * total_num_heads, 1, src_len)
if attn_mask is None:
attn_mask = key_padding_mask
elif attn_mask.dtype == torch.bool:
attn_mask = attn_mask.logical_or(key_padding_mask)
else:
attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
# convert mask to float
if attn_mask is not None and attn_mask.dtype == torch.bool:
new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
attn_mask = new_attn_mask
# adjust dropout probability
if not training:
dropout_p = 0.0
#
# (deep breath) calculate attention and out projection
#
attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, bsz, subset_heads, subset_weights)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output, None
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq.optim.amp_optimizer import AMPOptimizer
from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from .data.speech_to_text_dataset_with_domain import SpeechToTextDatasetCreatorWithDomain
from .loss.attention_head_selection import HeadSelectionLoss
@register_task("speech_to_text_head_selection")
class SpeechToTextHeadSelectionTask(SpeechToTextTask):
@classmethod
def add_args(cls, parser):
SpeechToTextTask.add_args(parser)
parser.add_argument(
"--task-type",
type=str,
default="lang",
help="task type for head selection, lang or domain"
)
parser.add_argument(
"--kl-weight",
type=float,
default=0.0,
help="the weight of KL loss"
)
def __init__(self, args, tgt_dict):
super().__init__(args, tgt_dict)
self.task_type = args.task_type
assert self.task_type in ["lang", "domain"], "invalid task_type: {}, should be either lang or domain".format(self.task_type)
self.map_task_to_id(args.train_subset)
        # prior probability of a head being selected = (#heads used) / (#total heads available)
        self.encoder_head_prior = float(args.encoder_attention_heads) / args.total_encoder_attention_heads
        self.decoder_head_prior = float(args.decoder_attention_heads) / args.total_decoder_attention_heads
self.kl_loss = HeadSelectionLoss(args)
def map_task_to_id(self, train_subset):
src_lang_set, tgt_lang_set, domain_set = set(), set(), set()
for split in train_subset.split(","):
seq = split.split("_")
assert len(seq) == 4, "subset {} should be in the format of train_src_tgt_domain".format(split)
_, src_lang, tgt_lang, domain = seq
src_lang_set.add(src_lang)
tgt_lang_set.add(tgt_lang)
domain_set.add(domain)
src_langs = sorted(src_lang_set)
tgt_langs = sorted(tgt_lang_set)
domains = sorted(domain_set)
self.src_lang_map = {src_lang: i for (i, src_lang) in enumerate(src_langs)}
self.tgt_lang_map = {tgt_lang: i for (i, tgt_lang) in enumerate(tgt_langs)}
self.domain_map = {domain: i for (i, domain) in enumerate(domains)}
if self.task_type == "lang":
self.encoder_tasks = len(self.src_lang_map)
self.decoder_tasks = len(self.tgt_lang_map)
elif self.task_type == "domain":
self.encoder_tasks = len(self.domain_map)
self.decoder_tasks = len(self.domain_map)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
is_train_split = split.startswith("train")
pre_tokenizer = self.build_tokenizer(self.args)
bpe_tokenizer = self.build_bpe(self.args)
self.datasets[split] = SpeechToTextDatasetCreatorWithDomain.from_tsv(
self.args.data,
self.data_cfg,
split,
self.tgt_dict,
pre_tokenizer,
bpe_tokenizer,
is_train_split=is_train_split,
epoch=epoch,
seed=self.args.seed,
src_lang_map=self.src_lang_map,
tgt_lang_map=self.tgt_lang_map,
domain_map=self.domain_map,
speaker_to_id=self.speaker_to_id
)
def build_model(self, args):
args.encoder_tasks = self.encoder_tasks
args.decoder_tasks = self.decoder_tasks
return super(SpeechToTextHeadSelectionTask, self).build_model(args)
def get_sample_sizes(self, sample, task_ids, num_tasks):
"""
task_ids: (bsz,)
get sample sizes for each task
"""
bsz = task_ids.size(0)
mat = torch.zeros((num_tasks, bsz), device=task_ids.device)
mat[task_ids, torch.arange(bsz)] = 1.0
        ntokens = torch.sum(sample['target'] != 1, dim=-1)  # 1 is fairseq's default padding index
sample_sizes = torch.matmul(mat, ntokens.float())
return sample_sizes
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
model.train()
model.set_num_updates(update_num)
# task ids
if self.task_type == "lang":
encoder_task_ids = sample["src_lang_ids"]
decoder_task_ids = sample["tgt_lang_ids"]
elif self.task_type == "domain":
encoder_task_ids = sample["domain_ids"]
decoder_task_ids = sample["domain_ids"]
model.encoder.set_task_ids(encoder_task_ids)
model.decoder.set_task_ids(decoder_task_ids)
with torch.autograd.profiler.record_function("forward"):
with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
loss, sample_size, logging_output = criterion(model, sample)
# KL loss
if self.args.encoder_attn_head_select:
sample_sizes = self.get_sample_sizes(sample, encoder_task_ids, self.encoder_tasks)
loss += self.kl_loss(
model.encoder.attn_head_selector.head_samples,
sample_sizes,
self.encoder_head_prior
)
if self.args.decoder_self_attn_head_select:
sample_sizes = self.get_sample_sizes(sample, decoder_task_ids, self.decoder_tasks)
loss += self.kl_loss(
model.decoder.self_attn_head_selector.head_samples,
sample_sizes,
self.decoder_head_prior
)
if self.args.dec_enc_attn_head_select:
sample_sizes = self.get_sample_sizes(sample, decoder_task_ids, self.decoder_tasks)
loss += self.kl_loss(
                model.decoder.enc_attn_head_selector.head_samples,
sample_sizes,
self.decoder_head_prior
)
if ignore_grad:
loss *= 0
with torch.autograd.profiler.record_function("backward"):
optimizer.backward(loss)
return loss, sample_size, logging_output
def valid_step(self, sample, model, criterion):
model.eval()
# task ids
if self.task_type == "lang":
encoder_task_ids = sample["src_lang_ids"]
decoder_task_ids = sample["tgt_lang_ids"]
elif self.task_type == "domain":
encoder_task_ids = sample["domain_ids"]
decoder_task_ids = sample["domain_ids"]
model.encoder.set_task_ids(encoder_task_ids)
model.decoder.set_task_ids(decoder_task_ids)
with torch.no_grad():
loss, sample_size, logging_output = criterion(model, sample)
return loss, sample_size, logging_output
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
with torch.no_grad():
# task ids
if self.task_type == "lang":
encoder_task_ids = sample["src_lang_ids"][:1]
decoder_task_ids = sample["tgt_lang_ids"][:1]
elif self.task_type == "domain":
encoder_task_ids = sample["domain_ids"][:1]
decoder_task_ids = sample["domain_ids"][:1]
for model in models:
model.encoder.set_task_ids(encoder_task_ids)
model.decoder.set_task_ids(decoder_task_ids)
return generator.generate(
models, sample, prefix_tokens=prefix_tokens, constraints=constraints
)
# End-to-end NLU
End-to-end spoken language understanding (SLU) predicts intent directly from audio using a single model. It promises to improve the performance of assistant systems by leveraging acoustic information lost in the intermediate textual representation and preventing cascading errors from Automatic Speech Recognition (ASR). Further, having one unified model has efficiency advantages when deploying assistant systems on-device.
This page provides the code for reproducing the results in [STOP: A dataset for Spoken Task Oriented Semantic Parsing](https://arxiv.org/abs/2207.10643).
The dataset can be downloaded here: [download link](https://dl.fbaipublicfiles.com/stop/stop.tar.gz)
The low-resource splits can be downloaded here: [download link](http://dl.fbaipublicfiles.com/stop/low_resource_splits.tar.gz)
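For example (an illustrative sketch; `$STOP_DOWNLOAD_DIR` is a placeholder directory, and the archive is assumed to unpack into a `stop/` subdirectory, which the manifest command further below expects):
```
mkdir -p $STOP_DOWNLOAD_DIR
wget https://dl.fbaipublicfiles.com/stop/stop.tar.gz -P $STOP_DOWNLOAD_DIR
tar -xzf $STOP_DOWNLOAD_DIR/stop.tar.gz -C $STOP_DOWNLOAD_DIR   # should yield $STOP_DOWNLOAD_DIR/stop
```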
## Pretrained End-to-end NLU Models
| Speech Pretraining | ASR Pretraining | Test EM Accuracy | Test EM-Tree Accuracy | Link |
| ----------- | ----------- |----------|----------|----------|
| None | None | 36.54 | 57.01 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-none-none.pt) |
| Wav2Vec | None | 68.05 | 82.53 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-none.pt) |
| HuBERT | None | 68.40 | 82.85 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-none.pt) |
| Wav2Vec | STOP | 68.70 | 82.78 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-stop.pt) |
| HuBERT | STOP | 69.23 | 82.87 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-stop.pt) |
| Wav2Vec | Librispeech | 68.47 | 82.49 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-ls.pt) |
| HuBERT | Librispeech | 68.70 | 82.78 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-ls.pt) |
## Pretrained ASR Models
| Speech Pre-training | ASR Dataset | STOP Eval WER | STOP Test WER | dev\_other WER | dev\_clean WER | test\_clean WER | test\_other WER | Link |
| ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- |
| HuBERT | Librispeech | 8.47 | 2.99 | 3.25 | 8.06 | 25.68 | 26.19 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-ls.pt) |
| Wav2Vec | Librispeech | 9.215 | 3.204 | 3.334 | 9.006 | 27.257 | 27.588 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-ls.pt) |
| HuBERT | STOP | 46.31 | 31.30 | 31.52 | 47.16 | 4.29 | 4.26 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-stop.pt) |
| Wav2Vec | STOP | 43.103 | 27.833 | 28.479 | 28.479 | 4.679 | 4.667 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-stop.pt) |
| HuBERT | Librispeech + STOP | 9.015 | 3.211 | 3.372 | 8.635 | 5.133 | 5.056 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-ls-stop.pt) |
| Wav2Vec | Librispeech + STOP | 9.549 | 3.537 | 3.625 | 9.514 | 5.59 | 5.562 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-ls-stop.pt) |
## Creating the fairseq datasets from STOP
First, create the audio file manifests and label files:
```
python examples/audio_nlp/nlu/generate_manifests.py --stop_root $STOP_DOWNLOAD_DIR/stop --output $FAIRSEQ_DATASET_OUTPUT/
```
Run `./examples/audio_nlp/nlu/create_dict_stop.sh $FAIRSEQ_DATASET_OUTPUT` to generate the fairseq dictionaries.
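As a quick sanity check (a sketch, assuming `generate_manifests.py` writes per-split label files with the prefixes that `create_dict_stop.sh`, reproduced further below, passes to `fairseq_cli/preprocess.py`), the output directory should contain `parse` and `ltr` files for at least the `train` and `eval` splits before running the dictionary step:
```
ls $FAIRSEQ_DATASET_OUTPUT
# expected to include, per create_dict_stop.sh below: train.parse  eval.parse  train.ltr  eval.ltr  ...
```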
## Training an End-to-end NLU Model
Download a HuBERT or wav2vec 2.0 checkpoint from the [HuBERT examples page](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert) or the [wav2vec examples page](https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec), then launch fine-tuning:
```
python fairseq_cli/hydra_train.py --config-dir examples/audio_nlp/nlu/configs/ --config-name nlu_finetuning task.data=$FAIRSEQ_DATASET_OUTPUT model.w2v_path=$PRETRAINED_MODEL_PATH
```
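A fuller invocation might look like the following (a sketch; the checkpoint filename and directories are placeholders, and any field of the `nlu_finetuning` config shown below can be overridden on the command line in Hydra's `key=value` form):
```
# placeholder paths -- adjust to your checkout, data directory, and downloaded checkpoint
FAIRSEQ_DATASET_OUTPUT=~/datasets/stop_fairseq
PRETRAINED_MODEL_PATH=~/checkpoints/hubert_base_ls960.pt   # e.g. a HuBERT base checkpoint

python fairseq_cli/hydra_train.py \
    --config-dir examples/audio_nlp/nlu/configs/ \
    --config-name nlu_finetuning \
    task.data=$FAIRSEQ_DATASET_OUTPUT \
    model.w2v_path=$PRETRAINED_MODEL_PATH \
    dataset.max_tokens=1600000   # example override; matches the default in the config below
```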
# @package _group_
common:
fp16: true
log_format: json
log_interval: 10
tensorboard_logdir: tb
checkpoint:
no_epoch_checkpoints: true
best_checkpoint_metric: em_error
save_interval: 10
task:
_name: nlu_finetuning
data: ???
labels: parse
eval_wer_parse: true
autoregressive: true
dataset:
num_workers: 6
max_tokens: 1600000
skip_invalid_size_inputs_valid_test: true
valid_subset: eval,test
train_subset: train
validate_interval: 10
criterion:
_name: label_smoothed_cross_entropy
optimization:
max_update: 320000
lr: [0.0001]
sentence_avg: true
update_freq: [1]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-08
lr_scheduler:
_name: tri_stage
phase_ratio: [0.1, 0.4, 0.5]
final_lr_scale: 0.05
model:
_name: wav2vec_seq2seq
w2v_path: ???
autoregressive: true
apply_mask: true
mask_prob: 0.5
mask_channel_prob: 0.5
mask_channel_length: 64
layerdrop: 0.1
activation_dropout: 0.1
feature_grad_mult: 0.0
freeze_finetune_updates: 0
#!/bin/bash
### Script handling creation of data binaries
### for model training within fairseq
fairseq_root="."
data_root=$1
train_prefix="${data_root}/train"
valid_prefix="${data_root}/eval"
test_prefix="${data_root}/test"
dest_dir="$data_root/"
#echo "src dict: $src_dict" > "$dest_dir/src_dict.txt"
#echo "trg dict: $tgt_dict" > "$dest_dir/tgt_dict.txt"
#--tgtdict $tgt_dict \
PYTHONPATH=$fairseq_root \
python $fairseq_root/fairseq_cli/preprocess.py \
--source-lang "parse" \
--trainpref "$train_prefix" \
--validpref "$valid_prefix" \
--destdir "$dest_dir" \
--only-source \
--dict-only \
--workers 60;
PYTHONPATH=$fairseq_root \
python $fairseq_root/fairseq_cli/preprocess.py \
--source-lang "ltr" \
--trainpref "$train_prefix" \
--validpref "$valid_prefix" \
--destdir "$dest_dir" \
--only-source \
--dict-only \
--workers 60;