"git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "9e36648284c47fda760e11f10bb0a6968b78ae8e"
Commit c6624fa6 authored by Pingchuan Ma, committed by Facebook GitHub Bot

Add LRS3 AV-ASR recipe (#3278)

Summary:
This PR adds an AV-ASR recipe containing sample implementations of training and evaluation pipelines for RNN-T-based audio, visual, and audio-visual speech recognition (ASR, VSR, AV-ASR) models on LRS3, in both streaming and non-streaming modes.

CC stavros99 xiaohui-zhang YumengTao mthrok nateanl hwangjeff

Pull Request resolved: https://github.com/pytorch/audio/pull/3278

Reviewed By: nateanl

Differential Revision: D46121550

Pulled By: mpc001

fbshipit-source-id: bb44b97ae25e87df2a73a707008be46af4ad0fc6
<p align="center"><img width="160" src="doc/lip_white.png" alt="logo"></p>
<h1 align="center">RNN-T ASR/VSR/AV-ASR Examples</h1>
This repository contains sample implementations of training and evaluation pipelines for RNN-T-based audio, visual, and audio-visual speech recognition (ASR, VSR, AV-ASR) models on LRS3, in both streaming (online) and non-streaming (offline) modes.
## Preparation
1. Set up the environment.
```Shell
conda create -y -n autoavsr python=3.8
conda activate autoavsr
```
2. Install the nightly versions of PyTorch, Torchvision, and Torchaudio by following the [official instructions](https://pytorch.org/get-started/), then install the remaining required packages:
```Shell
pip install pytorch-lightning sentencepiece
```
3. Preprocess LRS3 into a cropped-face dataset by following the instructions in the [data_prep](./data_prep) folder.
4. Download the pre-trained models below to initialise the ASR/VSR front-ends.
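5. (Optional) Train the SentencePiece tokenizer expected by the commands below at `./spm_unigram_1023.model`. One way to generate it, based on the included `train_spm.py` (which expects an `LRS3_text_seg24s` folder of transcripts under the given path), is:
```Shell
python train_spm.py --lrs3-path [lrs3_path] --output-file ./spm_unigram_1023.model
```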
### Training A/V-ASR model
- `[dataset_path]` is the directory of the original dataset.
- `[label_path]` is the labels directory.
- `[modality]` is the input modality type, including `v`, `a`, and `av`.
- `[mode]` is the model type, including `online` and `offline`.
```Shell
python train.py --dataset-path [dataset_path] \
--label-path [label_path] \
--pretrained-model-path [pretrained_model_path] \
--sp-model-path ./spm_unigram_1023.model \
--exp-dir ./exp \
--num-nodes 8 \
--gpus 8 \
--md [modality] \
--mode [mode]
```
### Training AV-ASR model
```Shell
python train.py --dataset-path [dataset_path] \
--label-path [label_path] \
--pretrained-vid-model-path [pretrained_vid_model_path] \
--pretrained-aud-model-path [pretrained_aud_model_path] \
--sp-model-path ./spm_unigram_1023.model \
--exp-dir ./exp \
--num-nodes 8 \
--gpus 8 \
--md av \
--mode [mode]
```
### Evaluating models
```Shell
python eval.py --dataset-path [dataset_path] \
--label-path [label_path] \
--pretrained-model-path [pretrained_model_path] \
--sp-model-path ./spm_unigram_1023.model \
--md [modality] \
--mode [mode] \
--checkpoint-path [checkpoint_path]
```
The table below reports the word error rate (WER) of the AV-ASR models.
| Model | WER [%] | Params (M) |
|:-----------:|:------------:|:--------------:|
| Non-streaming models | | |
| AV-ASR | 4.2 | 50 |
| Streaming models | | |
| AV-ASR | 4.9 | 40 |
import os
import torch
def average_checkpoints(last):
avg = None
for path in last:
states = torch.load(path, map_location=lambda storage, loc: storage)["state_dict"]
if avg is None:
avg = states
else:
for k in avg.keys():
avg[k] += states[k]
# average
for k in avg.keys():
if avg[k] is not None:
if avg[k].is_floating_point():
avg[k] /= len(last)
else:
avg[k] //= len(last)
return avg
def ensemble(args):
last = [
os.path.join(args.exp_dir, args.experiment_name, f"epoch={n}.ckpt")
for n in range(args.epochs - 10, args.epochs)
]
    model_path = os.path.join(args.exp_dir, args.experiment_name, "model_avg_10.pth")
torch.save({"state_dict": average_checkpoints(last)}, model_path)
import os
import random
import torch
import torchaudio
from lrs3 import LRS3
from pytorch_lightning import LightningDataModule
def _batch_by_token_count(idx_target_lengths, max_frames, batch_size=None):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > max_frames or (batch_size and len(current_batch) == batch_size):
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
class CustomBucketDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset,
lengths,
max_frames,
num_buckets,
shuffle=False,
batch_size=None,
):
super().__init__()
assert len(dataset) == len(lengths)
self.dataset = dataset
max_length = max(lengths)
min_length = min(lengths)
assert max_frames >= max_length
buckets = torch.linspace(min_length, max_length, num_buckets)
lengths = torch.tensor(lengths)
bucket_assignments = torch.bucketize(lengths, buckets)
idx_length_buckets = [(idx, length, bucket_assignments[idx]) for idx, length in enumerate(lengths)]
if shuffle:
idx_length_buckets = random.sample(idx_length_buckets, len(idx_length_buckets))
else:
idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[1], reverse=True)
sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2])
self.batches = _batch_by_token_count(
[(idx, length) for idx, length, _ in sorted_idx_length_buckets],
max_frames,
batch_size=batch_size,
)
def __getitem__(self, idx):
return [self.dataset[subidx] for subidx in self.batches[idx]]
def __len__(self):
return len(self.batches)
class TransformDataset(torch.utils.data.Dataset):
def __init__(self, dataset, transform_fn):
self.dataset = dataset
self.transform_fn = transform_fn
def __getitem__(self, idx):
return self.transform_fn(self.dataset[idx])
def __len__(self):
return len(self.dataset)
class LRS3DataModule(LightningDataModule):
def __init__(
self,
*,
args,
train_transform,
val_transform,
test_transform,
max_frames,
batch_size=None,
train_num_buckets=50,
train_shuffle=True,
num_workers=10,
):
super().__init__()
self.args = args
self.train_dataset_lengths = None
self.val_dataset_lengths = None
self.train_transform = train_transform
self.val_transform = val_transform
self.test_transform = test_transform
self.max_frames = max_frames
self.batch_size = batch_size
self.train_num_buckets = train_num_buckets
self.train_shuffle = train_shuffle
self.num_workers = num_workers
def train_dataloader(self):
datasets = [LRS3(self.args, subset="train")]
if not self.train_dataset_lengths:
self.train_dataset_lengths = [dataset._lengthlist for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(
dataset,
lengths,
self.max_frames,
self.train_num_buckets,
batch_size=self.batch_size,
)
for dataset, lengths in zip(datasets, self.train_dataset_lengths)
]
)
dataset = TransformDataset(dataset, self.train_transform)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=self.num_workers,
batch_size=None,
shuffle=self.train_shuffle,
)
return dataloader
def val_dataloader(self):
datasets = [LRS3(self.args, subset="val")]
if not self.val_dataset_lengths:
self.val_dataset_lengths = [dataset._lengthlist for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(
dataset,
lengths,
self.max_frames,
1,
batch_size=self.batch_size,
)
for dataset, lengths in zip(datasets, self.val_dataset_lengths)
]
)
dataset = TransformDataset(dataset, self.val_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers)
return dataloader
def test_dataloader(self):
dataset = LRS3(self.args, subset="test")
dataset = TransformDataset(dataset, self.test_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)
return dataloader
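# Wiring note: the data module is normally not instantiated directly; `get_data_module`
# in transforms.py builds it from the CLI args together with the train/val/test
# transforms and the SentencePiece model path, e.g.
#
#     from transforms import get_data_module
#     data_module = get_data_module(args, str(args.sp_model_path))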
This repository provides a pre-processing pipeline for detecting and cropping full-face images.
import logging
from argparse import ArgumentParser
import sentencepiece as spm
import torch
import torchaudio
from transforms import get_data_module
logger = logging.getLogger(__name__)
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def get_lightning_module(args):
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
if args.md == "av":
from lightning_av import AVConformerRNNTModule
model = AVConformerRNNTModule(args, sp_model)
else:
from lightning import ConformerRNNTModule
model = ConformerRNNTModule(args, sp_model)
ckpt = torch.load(args.checkpoint_path, map_location=lambda storage, loc: storage)["state_dict"]
model.load_state_dict(ckpt)
model.eval()
return model
def run_eval(model, data_module):
total_edit_distance = 0
total_length = 0
dataloader = data_module.test_dataloader()
with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader):
actual = sample[0][-1]
predicted = model(batch)
total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split())
if idx % 100 == 0:
logger.warning(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.warning(f"Final WER: {total_edit_distance / total_length}")
return total_edit_distance / total_length
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--md",
type=str,
help="Modality",
required=True,
)
parser.add_argument(
"--mode",
type=str,
help="Perform online or offline recognition.",
required=True,
)
parser.add_argument(
"--dataset-path",
type=str,
help="Path to LRW audio-visual datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=str,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--checkpoint-path",
type=str,
help="Path to checkpoint model.",
required=True,
)
parser.add_argument(
"--pretrained-model-path",
type=str,
help="Path to Pretraned model.",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = parse_args()
init_logger(args.debug)
model = get_lightning_module(args)
data_module = get_data_module(args, str(args.sp_model_path))
run_eval(model, data_module)
if __name__ == "__main__":
cli_main()
import itertools
import logging
import math
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from models.conformer_rnnt import conformer_rnnt
from models.emformer_rnnt import emformer_rnnt
from models.resnet import video_resnet
from models.resnet1d import audio_resnet
from pytorch_lightning import LightningModule
from schedulers import WarmupCosineScheduler
from torchaudio.models import Hypothesis, RNNTBeamSearch
_expected_spm_vocab_size = 1023
Batch = namedtuple("Batch", ["inputs", "input_lengths", "targets", "target_lengths"])
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
tokens_idx = 0
score_idx = 3
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ids = [h[tokens_idx][1:] for h in hypos]
hypos_score = [[math.exp(h[score_idx])] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
return nbest_batch
class ConformerRNNTModule(LightningModule):
def __init__(self, args=None, sp_model=None, pretrained_model_path=None):
super().__init__()
self.save_hyperparameters(args)
self.args = args
self.sp_model = sp_model
spm_vocab_size = self.sp_model.get_piece_size()
assert spm_vocab_size == _expected_spm_vocab_size, (
"The model returned by conformer_rnnt_base expects a SentencePiece model of "
f"vocabulary size {_expected_spm_vocab_size}, but the given SentencePiece model has a vocabulary size "
f"of {spm_vocab_size}. Please provide a correctly configured SentencePiece model."
)
self.blank_idx = spm_vocab_size
if args.md == "v":
self.frontend = video_resnet()
if args.md == "a":
self.frontend = audio_resnet()
if args.mode == "online":
self.model = emformer_rnnt()
if args.mode == "offline":
self.model = conformer_rnnt()
# -- initialise
if args.pretrained_model_path:
ckpt = torch.load(args.pretrained_model_path, map_location=lambda storage, loc: storage)
tmp_ckpt = {
k.replace("encoder.frontend.", ""): v for k, v in ckpt.items() if k.startswith("encoder.frontend.")
}
self.frontend.load_state_dict(tmp_ckpt)
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum")
self.optimizer = torch.optim.AdamW(
itertools.chain(*([self.frontend.parameters(), self.model.parameters()])),
lr=8e-4,
weight_decay=0.06,
betas=(0.9, 0.98),
)
self.automatic_optimization = False
def _step(self, batch, _, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
features = self.frontend(batch.inputs)
output, src_lengths, _, _ = self.model(
features, batch.input_lengths, prepended_targets, prepended_target_lengths
)
loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
self.warmup_lr_scheduler = WarmupCosineScheduler(
self.optimizer,
10,
self.args.epochs,
len(self.trainer.datamodule.train_dataloader()) / self.trainer.num_devices / self.trainer.num_nodes,
)
self.lr_scheduler_interval = "step"
return (
[self.optimizer],
[{"scheduler": self.warmup_lr_scheduler, "interval": self.lr_scheduler_interval}],
)
def forward(self, batch: Batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
x = self.frontend(batch.inputs.to(self.device))
hypotheses = decoder(x, batch.input_lengths.to(self.device), beam_width=20)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch: Batch, batch_idx):
"""Custom training step.
By default, DDP does the following on each train step:
- For each GPU, compute loss and gradient on shard of training data.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / N, where N is the world
size (total number of GPUs).
- Update parameters on each GPU.
Here, we do the following:
- For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / B_total.
- Update parameters on each GPU.
Doing so allows us to account for the variability in batch sizes that
variable-length sequential data commonly yields.
"""
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
batch_size = batch.inputs.size(0)
batch_sizes = self.all_gather(batch_size)
        loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / total batch size across GPUs
self.manual_backward(loss)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
opt.step()
sch = self.lr_schedulers()
sch.step()
self.log("monitoring_step", self.global_step)
return loss
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
import itertools
import logging
import math
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from models.conformer_rnnt import conformer_rnnt
from models.emformer_rnnt import emformer_rnnt
from models.fusion import fusion_module
from models.resnet import video_resnet
from models.resnet1d import audio_resnet
from pytorch_lightning import LightningModule
from schedulers import WarmupCosineScheduler
from torchaudio.models import Hypothesis, RNNTBeamSearch
_expected_spm_vocab_size = 1023
AVBatch = namedtuple("AVBatch", ["audios", "videos", "audio_lengths", "video_lengths", "targets", "target_lengths"])
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
tokens_idx = 0
score_idx = 3
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ids = [h[tokens_idx][1:] for h in hypos]
hypos_score = [[math.exp(h[score_idx])] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
return nbest_batch
class AVConformerRNNTModule(LightningModule):
def __init__(self, args=None, sp_model=None):
super().__init__()
self.save_hyperparameters(args)
self.args = args
self.sp_model = sp_model
spm_vocab_size = self.sp_model.get_piece_size()
assert spm_vocab_size == _expected_spm_vocab_size, (
"The model returned by conformer_rnnt_base expects a SentencePiece model of "
f"vocabulary size {_expected_spm_vocab_size}, but the given SentencePiece model has a vocabulary size "
f"of {spm_vocab_size}. Please provide a correctly configured SentencePiece model."
)
self.blank_idx = spm_vocab_size
self.audio_frontend = audio_resnet()
self.video_frontend = video_resnet()
self.fusion = fusion_module()
frontend_params = [self.video_frontend.parameters(), self.audio_frontend.parameters()]
fusion_params = [self.fusion.parameters()]
if args.mode == "online":
self.model = emformer_rnnt()
if args.mode == "offline":
self.model = conformer_rnnt()
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum")
self.optimizer = torch.optim.AdamW(
itertools.chain(*([self.model.parameters()] + frontend_params + fusion_params)),
lr=8e-4,
weight_decay=0.06,
betas=(0.9, 0.98),
)
self.automatic_optimization = False
def _step(self, batch, _, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
video_features = self.video_frontend(batch.videos)
audio_features = self.audio_frontend(batch.audios)
output, src_lengths, _, _ = self.model(
self.fusion(torch.cat([video_features, audio_features], dim=-1)),
batch.video_lengths,
prepended_targets,
prepended_target_lengths,
)
loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
self.warmup_lr_scheduler = WarmupCosineScheduler(
self.optimizer,
10,
self.args.epochs,
len(self.trainer.datamodule.train_dataloader()) / self.trainer.num_devices / self.trainer.num_nodes,
)
self.lr_scheduler_interval = "step"
return (
[self.optimizer],
[{"scheduler": self.warmup_lr_scheduler, "interval": self.lr_scheduler_interval}],
)
def forward(self, batch: AVBatch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
video_features = self.video_frontend(batch.videos.to(self.device))
audio_features = self.audio_frontend(batch.audios.to(self.device))
hypotheses = decoder(
self.fusion(torch.cat([video_features, audio_features], dim=-1)),
batch.video_lengths.to(self.device),
beam_width=20,
)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch: AVBatch, batch_idx):
"""Custom training step.
By default, DDP does the following on each train step:
- For each GPU, compute loss and gradient on shard of training data.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / N, where N is the world
size (total number of GPUs).
- Update parameters on each GPU.
Here, we do the following:
- For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / B_total.
- Update parameters on each GPU.
Doing so allows us to account for the variability in batch sizes that
variable-length sequential data commonly yields.
"""
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
batch_size = batch.videos.size(0)
batch_sizes = self.all_gather(batch_size)
        loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / total batch size across GPUs
self.manual_backward(loss)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
opt.step()
sch = self.lr_schedulers()
sch.step()
self.log("monitoring_step", self.global_step)
return loss
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
import os
from pathlib import Path
from typing import Tuple, Union
import torch
import torchaudio
import torchvision
from torch import Tensor
from torch.utils.data import Dataset
def _load_list(args, *filenames):
output = []
length = []
for filename in filenames:
filepath = os.path.join(os.path.dirname(args.dataset_path), filename)
for line in open(filepath).read().splitlines():
dataset_name, rel_path, input_length = line.split(",")[0], line.split(",")[1], line.split(",")[2]
path = os.path.normpath(os.path.join(args.dataset_path, rel_path[:-4] + ".mp4"))
length.append(int(input_length))
output.append(path)
return output, length
def load_video(path):
"""
    rtype: torch.Tensor, T x C x H x W
"""
vid = torchvision.io.read_video(path, pts_unit="sec", output_format="THWC")[0]
vid = vid.permute((0, 3, 1, 2))
return vid
def load_audio(path):
"""
    rtype: torch.Tensor, T x 1
"""
waveform, sample_rate = torchaudio.load(path, normalize=True)
return waveform.transpose(1, 0)
def load_transcript(path):
transcript_path = path.replace("video_seg", "text_seg")[:-4] + ".txt"
return open(transcript_path).read().splitlines()[0]
def load_item(path, md):
if md == "v":
return (load_video(path), load_transcript(path))
if md == "a":
return (load_audio(path), load_transcript(path))
if md == "av":
return (load_audio(path), load_video(path), load_transcript(path))
class LRS3(Dataset):
def __init__(
self,
args,
subset: str = "train",
) -> None:
if subset is not None and subset not in ["train", "val", "test"]:
raise ValueError("When `subset` is not None, it must be one of ['train', 'val', 'test'].")
self.args = args
if subset == "train":
self._filelist, self._lengthlist = _load_list(self.args, "train_transcript_lengths_seg16s.csv")
if subset == "val":
self._filelist, self._lengthlist = _load_list(self.args, "test_transcript_lengths_seg16s.csv")
if subset == "test":
self._filelist, self._lengthlist = _load_list(self.args, "test_transcript_lengths_seg16s.csv")
def __getitem__(self, n):
path = self._filelist[n]
return load_item(path, self.args.md)
def __len__(self) -> int:
return len(self._filelist)
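# Label-file note (inferred from _load_list above): each line of the
# *_transcript_lengths_seg16s.csv label files carries at least three comma-separated
# fields, "<dataset_name>,<relative_path>,<length>"; the extension of the relative path
# is swapped for ".mp4" to locate the video segment, and the length is used for bucketing.
# A purely illustrative line (hypothetical path, the real layout comes from data_prep):
#
#     lrs3,foo/video_seg16s/00001.mp4,155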
from torchaudio.prototype.models import conformer_rnnt_model
# https://pytorch.org/audio/master/_modules/torchaudio/prototype/models/rnnt.html#conformer_rnnt_model
def conformer_rnnt():
return conformer_rnnt_model(
input_dim=512,
encoding_dim=1024,
time_reduction_stride=1,
conformer_input_dim=256,
conformer_ffn_dim=1024,
conformer_num_layers=16,
conformer_num_heads=4,
conformer_depthwise_conv_kernel_size=31,
conformer_dropout=0.1,
num_symbols=1024,
symbol_embedding_dim=256,
num_lstm_layers=2,
lstm_hidden_dim=512,
lstm_layer_norm=True,
lstm_layer_norm_epsilon=1e-5,
lstm_dropout=0.3,
joiner_activation="tanh",
)
from torchaudio.models.rnnt import emformer_rnnt_model
# https://pytorch.org/audio/master/_modules/torchaudio/models/rnnt.html#emformer_rnnt_base
def emformer_rnnt():
return emformer_rnnt_model(
input_dim=512,
encoding_dim=1024,
num_symbols=1024,
segment_length=64,
right_context_length=0,
time_reduction_input_dim=128,
time_reduction_stride=1,
transformer_num_heads=4,
transformer_ffn_dim=2048,
transformer_num_layers=20,
transformer_dropout=0.1,
transformer_activation="gelu",
transformer_left_context_length=30,
transformer_max_memory_size=0,
transformer_weight_init_scale_strategy="depthwise",
transformer_tanh_on_mem=True,
symbol_embedding_dim=512,
num_lstm_layers=3,
lstm_layer_norm=True,
lstm_layer_norm_epsilon=1e-3,
lstm_dropout=0.3,
)
import torch
class FeedForwardModule(torch.nn.Module):
r"""Positionwise feed forward layer.
Args:
input_dim (int): input dimension.
hidden_dim (int): hidden dimension.
dropout (float, optional): dropout probability. (Default: 0.0)
"""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.0) -> None:
super().__init__()
self.sequential = torch.nn.Sequential(
torch.nn.LayerNorm(input_dim),
torch.nn.Linear(input_dim, hidden_dim, bias=True),
torch.nn.SiLU(),
torch.nn.Dropout(dropout),
torch.nn.Linear(hidden_dim, output_dim, bias=True),
torch.nn.Dropout(dropout),
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
r"""
Args:
input (torch.Tensor): with shape `(*, D)`.
Returns:
torch.Tensor: output, with shape `(*, D)`.
"""
return self.sequential(input)
def fusion_module():
return FeedForwardModule(1024, 3072, 512, 0.1)
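# Shape note (a sketch of how lightning_av.py uses this module): the video and audio
# front-ends each emit 512-dim features, which are concatenated along the last dimension
# and projected back to the 512-dim transducer input, e.g.
#
#     fusion = fusion_module()
#     fused = fusion(torch.cat([video_features, audio_features], dim=-1))  # (..., 1024) -> (..., 512)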
import torch.nn as nn
def conv3x3(in_planes, out_planes, stride=1):
"""conv3x3.
:param in_planes: int, number of channels in the input sequence.
:param out_planes: int, number of channels produced by the convolution.
:param stride: int, size of the convolving kernel.
"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
def downsample_basic_block(inplanes, outplanes, stride):
"""downsample_basic_block.
:param inplanes: int, number of channels in the input sequence.
:param outplanes: int, number of channels produced by the convolution.
:param stride: int, size of the convolving kernel.
"""
return nn.Sequential(
nn.Conv2d(
inplanes,
outplanes,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(outplanes),
)
class BasicBlock(nn.Module):
expansion = 1
def __init__(
self,
inplanes,
planes,
stride=1,
downsample=None,
relu_type="swish",
):
"""__init__.
:param inplanes: int, number of channels in the input sequence.
:param planes: int, number of channels produced by the convolution.
:param stride: int, size of the convolving kernel.
:param downsample: boolean, if True, the temporal resolution is downsampled.
:param relu_type: str, type of activation function.
"""
super(BasicBlock, self).__init__()
assert relu_type in ["relu", "prelu", "swish"]
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
if relu_type == "relu":
self.relu1 = nn.ReLU(inplace=True)
self.relu2 = nn.ReLU(inplace=True)
elif relu_type == "prelu":
self.relu1 = nn.PReLU(num_parameters=planes)
self.relu2 = nn.PReLU(num_parameters=planes)
elif relu_type == "swish":
self.relu1 = nn.SiLU(inplace=True)
self.relu2 = nn.SiLU(inplace=True)
else:
raise NotImplementedError
# --------
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
"""forward.
:param x: torch.Tensor, input tensor with input size (B, C, T, H, W).
"""
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu2(out)
return out
class ResNet(nn.Module):
def __init__(
self,
block,
layers,
relu_type="swish",
):
super(ResNet, self).__init__()
self.inplanes = 64
self.relu_type = relu_type
self.downsample_block = downsample_basic_block
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d(1)
def _make_layer(self, block, planes, blocks, stride=1):
"""_make_layer.
:param block: torch.nn.Module, class of blocks.
:param planes: int, number of channels produced by the convolution.
:param blocks: int, number of layers in a block.
:param stride: int, size of the convolving kernel.
"""
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = self.downsample_block(
inplanes=self.inplanes,
outplanes=planes * block.expansion,
stride=stride,
)
layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample,
relu_type=self.relu_type,
)
)
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
relu_type=self.relu_type,
)
)
return nn.Sequential(*layers)
def forward(self, x):
"""forward.
:param x: torch.Tensor, input tensor with input size (B, C, T, H, W).
"""
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
return x
# -- auxiliary functions
def threeD_to_2D_tensor(x):
n_batch, n_channels, s_time, sx, sy = x.shape
x = x.transpose(1, 2)
return x.reshape(n_batch * s_time, n_channels, sx, sy)
class Conv3dResNet(nn.Module):
"""Conv3dResNet module"""
def __init__(self, backbone_type="resnet", relu_type="swish"):
"""__init__.
:param backbone_type: str, the type of a visual front-end.
:param relu_type: str, activation function used in an audio front-end.
"""
super(Conv3dResNet, self).__init__()
self.backbone_type = backbone_type
self.frontend_nout = 64
self.trunk = ResNet(
BasicBlock,
[2, 2, 2, 2],
relu_type=relu_type,
)
# -- frontend3D
if relu_type == "relu":
frontend_relu = nn.ReLU(True)
elif relu_type == "prelu":
frontend_relu = nn.PReLU(self.frontend_nout)
elif relu_type == "swish":
frontend_relu = nn.SiLU(inplace=True)
self.frontend3D = nn.Sequential(
nn.Conv3d(
in_channels=1,
out_channels=self.frontend_nout,
kernel_size=(5, 7, 7),
stride=(1, 2, 2),
padding=(2, 3, 3),
bias=False,
),
nn.BatchNorm3d(self.frontend_nout),
frontend_relu,
nn.MaxPool3d(
kernel_size=(1, 3, 3),
stride=(1, 2, 2),
padding=(0, 1, 1),
),
)
def forward(self, xs_pad):
"""forward.
:param xs_pad: torch.Tensor, batch of padded input sequences.
"""
# -- include Channel dimension
xs_pad = xs_pad.transpose(2, 1)
B, C, T, H, W = xs_pad.size()
xs_pad = self.frontend3D(xs_pad)
        Tnew = xs_pad.shape[2]  # output should be B x C2 x Tnew x H x W
xs_pad = threeD_to_2D_tensor(xs_pad)
xs_pad = self.trunk(xs_pad)
xs_pad = xs_pad.view(B, Tnew, xs_pad.size(1))
return xs_pad
def video_resnet():
return Conv3dResNet()
import torch.nn as nn
def conv3x3(in_planes, out_planes, stride=1):
"""conv3x3.
:param in_planes: int, number of channels in the input sequence.
:param out_planes: int, number of channels produced by the convolution.
:param stride: int, size of the convolving kernel.
"""
return nn.Conv1d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
def downsample_basic_block(inplanes, outplanes, stride):
"""downsample_basic_block.
:param inplanes: int, number of channels in the input sequence.
:param outplanes: int, number of channels produced by the convolution.
:param stride: int, size of the convolving kernel.
"""
return nn.Sequential(
nn.Conv1d(
inplanes,
outplanes,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm1d(outplanes),
)
class BasicBlock1D(nn.Module):
expansion = 1
def __init__(
self,
inplanes,
planes,
stride=1,
downsample=None,
relu_type="relu",
):
"""__init__.
:param inplanes: int, number of channels in the input sequence.
:param planes: int, number of channels produced by the convolution.
:param stride: int, size of the convolving kernel.
:param downsample: boolean, if True, the temporal resolution is downsampled.
:param relu_type: str, type of activation function.
"""
super(BasicBlock1D, self).__init__()
assert relu_type in ["relu", "prelu", "swish"]
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm1d(planes)
# type of ReLU is an input option
if relu_type == "relu":
self.relu1 = nn.ReLU(inplace=True)
self.relu2 = nn.ReLU(inplace=True)
elif relu_type == "prelu":
self.relu1 = nn.PReLU(num_parameters=planes)
self.relu2 = nn.PReLU(num_parameters=planes)
elif relu_type == "swish":
self.relu1 = nn.SiLU(inplace=True)
self.relu2 = nn.SiLU(inplace=True)
else:
raise NotImplementedError
# --------
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm1d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
"""forward.
:param x: torch.Tensor, input tensor with input size (B, C, T)
"""
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu2(out)
return out
class ResNet1D(nn.Module):
def __init__(
self,
block,
layers,
relu_type="swish",
a_upsample_ratio=1,
):
"""__init__.
:param block: torch.nn.Module, class of blocks.
:param layers: List, customised layers in each block.
:param relu_type: str, type of activation function.
:param a_upsample_ratio: int, The ratio related to the \
temporal resolution of output features of the frontend. \
            a_upsample_ratio=1 produces features at 25 fps.
"""
super(ResNet1D, self).__init__()
self.inplanes = 64
self.relu_type = relu_type
self.downsample_block = downsample_basic_block
self.a_upsample_ratio = a_upsample_ratio
self.conv1 = nn.Conv1d(
in_channels=1,
out_channels=self.inplanes,
kernel_size=80,
stride=4,
padding=38,
bias=False,
)
self.bn1 = nn.BatchNorm1d(self.inplanes)
if relu_type == "relu":
self.relu = nn.ReLU(inplace=True)
elif relu_type == "prelu":
self.relu = nn.PReLU(num_parameters=self.inplanes)
elif relu_type == "swish":
self.relu = nn.SiLU(inplace=True)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool1d(
kernel_size=20 // self.a_upsample_ratio,
stride=20 // self.a_upsample_ratio,
)
def _make_layer(self, block, planes, blocks, stride=1):
"""_make_layer.
:param block: torch.nn.Module, class of blocks.
:param planes: int, number of channels produced by the convolution.
:param blocks: int, number of layers in a block.
:param stride: int, size of the convolving kernel.
"""
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = self.downsample_block(
inplanes=self.inplanes,
outplanes=planes * block.expansion,
stride=stride,
)
layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample,
relu_type=self.relu_type,
)
)
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
relu_type=self.relu_type,
)
)
return nn.Sequential(*layers)
def forward(self, x):
"""forward.
:param x: torch.Tensor, input tensor with input size (B, C, T)
"""
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
return x
class Conv1dResNet(nn.Module):
"""Conv1dResNet"""
def __init__(self, relu_type="swish", a_upsample_ratio=1):
"""__init__.
:param relu_type: str, Activation function used in an audio front-end.
:param a_upsample_ratio: int, The ratio related to the \
temporal resolution of output features of the frontend. \
            a_upsample_ratio=1 produces features at 25 fps.
"""
super(Conv1dResNet, self).__init__()
self.a_upsample_ratio = a_upsample_ratio
self.trunk = ResNet1D(BasicBlock1D, [2, 2, 2, 2], relu_type=relu_type, a_upsample_ratio=a_upsample_ratio)
def forward(self, xs_pad):
"""forward.
:param xs_pad: torch.Tensor, batch of padded input sequences (B, Tmax, idim)
"""
B, T, C = xs_pad.size()
xs_pad = xs_pad[:, : T // 640 * 640, :]
xs_pad = xs_pad.transpose(1, 2)
xs_pad = self.trunk(xs_pad)
# -- from B x C x T to B x T x C
xs_pad = xs_pad.transpose(1, 2)
return xs_pad
def audio_resnet():
return Conv1dResNet()
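# Rate note (inferred from the strides above): conv1 (stride 4), the three stride-2
# residual stages, and AvgPool1d(20) reduce the waveform by 4 * 2 * 2 * 2 * 20 = 640x,
# so audio sampled at 16 kHz (the assumed LRS3 rate) yields 25 feature frames per second,
# matching the 25 fps video front-end; this is why callers trim waveforms to a multiple
# of 640 samples.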
import math
import torch
class WarmupCosineScheduler(torch.optim.lr_scheduler._LRScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_epochs: int,
total_epochs: int,
steps_per_epoch: int,
last_epoch=-1,
verbose=False,
):
self.warmup_steps = warmup_epochs * steps_per_epoch
self.total_steps = total_epochs * steps_per_epoch
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
if self._step_count < self.warmup_steps:
return [self._step_count / self.warmup_steps * base_lr for base_lr in self.base_lrs]
else:
decay_steps = self.total_steps - self.warmup_steps
return [
0.5 * base_lr * (1 + math.cos(math.pi * (self._step_count - self.warmup_steps) / decay_steps))
for base_lr in self.base_lrs
]
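# Usage sketch (mirroring how the lightning modules configure this scheduler): 10 warmup
# epochs followed by a cosine decay over the remaining epochs, stepped once per optimizer
# step, e.g.
#
#     optimizer = torch.optim.AdamW(model.parameters(), lr=8e-4, weight_decay=0.06)
#     scheduler = WarmupCosineScheduler(optimizer, 10, args.epochs, steps_per_epoch)
#     ...
#     optimizer.step()
#     scheduler.step()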
import logging
import os
import pathlib
from argparse import ArgumentParser
import sentencepiece as spm
from average_checkpoints import ensemble
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from transforms import get_data_module
def get_trainer(args):
seed_everything(1)
checkpoint = ModelCheckpoint(
dirpath=os.path.join(args.exp_dir, args.experiment_name) if args.exp_dir else None,
monitor="monitoring_step",
mode="max",
save_last=True,
filename=f"{{epoch}}",
save_top_k=10,
)
lr_monitor = LearningRateMonitor(logging_interval="step")
callbacks = [
checkpoint,
lr_monitor,
]
return Trainer(
sync_batchnorm=True,
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.num_nodes,
devices=args.gpus,
accelerator="gpu",
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
resume_from_checkpoint=args.resume_from_checkpoint,
)
def get_lightning_module(args):
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
if args.md == "av":
from lightning_av import AVConformerRNNTModule
model = AVConformerRNNTModule(args, sp_model)
else:
from lightning import ConformerRNNTModule
model = ConformerRNNTModule(args, sp_model)
return model
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--md",
type=str,
help="Modality",
required=True,
)
parser.add_argument(
"--mode",
type=str,
help="Perform online or offline recognition.",
required=True,
)
parser.add_argument(
"--dataset-path",
type=str,
help="Path to LRW audio-visual datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=str,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--pretrained-model-path",
type=str,
help="Path to Pretraned model.",
)
parser.add_argument(
"--exp-dir",
type=str,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--experiment-name",
default="online_avsr_public_test",
type=str,
help="Experiment name",
)
parser.add_argument(
"--num-nodes",
default=8,
type=int,
help="Number of nodes to use for training. (Default: 8)",
)
parser.add_argument(
"--gpus",
default=8,
type=int,
help="Number of GPUs per node to use for training. (Default: 8)",
)
parser.add_argument(
"--epochs",
default=55,
type=int,
help="Number of epochs to train for. (Default: 55)",
)
parser.add_argument(
"--resume-from-checkpoint", default=None, type=str, help="Path to the checkpoint to resume from"
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = parse_args()
init_logger(args.debug)
model = get_lightning_module(args)
data_module = get_data_module(args, str(args.sp_model_path))
trainer = get_trainer(args)
trainer.fit(model, data_module)
ensemble(args)
if __name__ == "__main__":
cli_main()
#!/usr/bin/env python3
"""Trains a SentencePiece model on transcripts across LRS3 pretrain and trainval.
Example:
python train_spm.py --lrs3-path <LRS3-DIRECTORY>
"""
import io
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import sentencepiece as spm
def get_transcript_text(transcript_path):
return [open(transcript_path).read().splitlines()[0].lower()]
def get_transcripts(dataset_path):
transcript_paths = dataset_path.glob("*/*.txt")
merged_transcripts = []
for path in transcript_paths:
merged_transcripts += get_transcript_text(path)
return merged_transcripts
def train_spm(input):
model_writer = io.BytesIO()
spm.SentencePieceTrainer.train(
sentence_iterator=iter(input),
model_writer=model_writer,
vocab_size=1023,
model_type="unigram",
input_sentence_size=-1,
character_coverage=1.0,
bos_id=0,
pad_id=1,
eos_id=2,
unk_id=3,
)
return model_writer.getvalue()
def parse_args():
default_output_path = "./spm_unigram_1023.model"
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--lrs3-path",
type=pathlib.Path,
help="Path to LRS3 datasets.",
required=True,
)
parser.add_argument(
"--output-file",
default=pathlib.Path(default_output_path),
type=pathlib.Path,
help=f"File to save model to. (Default: '{default_output_path}')",
)
return parser.parse_args()
def run_cli():
args = parse_args()
root = args.lrs3_path / "LRS3_text_seg24s"
splits = ["pretrain", "trainval"]
merged_transcripts = []
for split in splits:
path = pathlib.Path(root) / split
merged_transcripts += get_transcripts(path)
model = train_spm(merged_transcripts)
with open(args.output_file, "wb") as f:
f.write(model)
if __name__ == "__main__":
run_cli()
import json
import math
import random
from functools import partial
from typing import List
import sentencepiece as spm
import torch
import torchaudio
import torchvision
from data_module import LRS3DataModule
from lightning import Batch
from lightning_av import AVBatch
class FunctionalModule(torch.nn.Module):
def __init__(self, functional):
super().__init__()
self.functional = functional
def forward(self, input):
return self.functional(input)
class AdaptiveTimeMask(torch.nn.Module):
def __init__(self, window, stride):
super().__init__()
self.window = window
self.stride = stride
def forward(self, x):
cloned = x.clone()
length = cloned.size(1)
n_mask = int((length + self.stride - 0.1) // self.stride)
ts = torch.randint(0, self.window, size=(n_mask, 2))
for t, t_end in ts:
if length - t <= 0:
continue
t_start = random.randrange(0, length - t)
if t_start == t_start + t:
continue
t_end += t_start
cloned[:, t_start:t_end] = 0
return cloned
def _extract_labels(sp_model, samples: List):
targets = [sp_model.encode(sample[-1].lower()) for sample in samples]
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths
def _extract_features(video_pipeline, audio_pipeline, samples, args):
raw_videos = []
raw_audios = []
for sample in samples:
if args.md == "v":
raw_videos.append(sample[0])
if args.md == "a":
raw_audios.append(sample[0])
if args.md == "av":
length = min(len(sample[0]) // 640, len(sample[1]))
raw_audios.append(sample[0][: length * 640])
raw_videos.append(sample[1][:length])
if args.md == "v" or args.md == "av":
videos = torch.nn.utils.rnn.pad_sequence(raw_videos, batch_first=True)
videos = video_pipeline(videos)
video_lengths = torch.tensor([elem.shape[0] for elem in videos], dtype=torch.int32)
if args.md == "a" or args.md == "av":
audios = torch.nn.utils.rnn.pad_sequence(raw_audios, batch_first=True)
audios = audio_pipeline(audios)
audio_lengths = torch.tensor([elem.shape[0] // 640 for elem in audios], dtype=torch.int32)
if args.md == "v":
return videos, video_lengths
if args.md == "a":
return audios, audio_lengths
if args.md == "av":
return audios, videos, audio_lengths, video_lengths
class TrainTransform:
def __init__(self, sp_model_path: str, args):
self.args = args
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.train_video_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: x / 255.0),
torchvision.transforms.RandomCrop(88),
torchvision.transforms.RandomHorizontalFlip(0.5),
FunctionalModule(lambda x: x.transpose(0, 1)),
torchvision.transforms.Grayscale(),
FunctionalModule(lambda x: x.transpose(0, 1)),
AdaptiveTimeMask(10, 25),
torchvision.transforms.Normalize(0.421, 0.165),
)
self.train_audio_pipeline = torch.nn.Sequential(
AdaptiveTimeMask(10, 25),
)
def __call__(self, samples: List):
targets, target_lengths = _extract_labels(self.sp_model, samples)
if self.args.md == "a":
audios, audio_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
return Batch(audios, audio_lengths, targets, target_lengths)
if self.args.md == "v":
videos, video_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
return Batch(videos, video_lengths, targets, target_lengths)
if self.args.md == "av":
audios, videos, audio_lengths, video_lengths = _extract_features(
self.train_video_pipeline, self.train_audio_pipeline, samples, self.args
)
return AVBatch(audios, videos, audio_lengths, video_lengths, targets, target_lengths)
class ValTransform:
def __init__(self, sp_model_path: str, args):
self.args = args
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.valid_video_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: x / 255.0),
torchvision.transforms.CenterCrop(88),
FunctionalModule(lambda x: x.transpose(0, 1)),
torchvision.transforms.Grayscale(),
FunctionalModule(lambda x: x.transpose(0, 1)),
torchvision.transforms.Normalize(0.421, 0.165),
)
self.valid_audio_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: x),
)
def __call__(self, samples: List):
targets, target_lengths = _extract_labels(self.sp_model, samples)
if self.args.md == "a":
audios, audio_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)
return Batch(audios, audio_lengths, targets, target_lengths)
if self.args.md == "v":
videos, video_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)
return Batch(videos, video_lengths, targets, target_lengths)
if self.args.md == "av":
audios, videos, audio_lengths, video_lengths = _extract_features(
self.valid_video_pipeline, self.valid_audio_pipeline, samples, self.args
)
return AVBatch(audios, videos, audio_lengths, video_lengths, targets, target_lengths)
class TestTransform:
def __init__(self, sp_model_path: str, args):
self.val_transforms = ValTransform(sp_model_path, args)
def __call__(self, sample):
return self.val_transforms([sample]), [sample]
def get_data_module(args, sp_model_path, max_frames=1800):
train_transform = TrainTransform(sp_model_path=sp_model_path, args=args)
val_transform = ValTransform(sp_model_path=sp_model_path, args=args)
test_transform = TestTransform(sp_model_path=sp_model_path, args=args)
return LRS3DataModule(
args=args,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
max_frames=max_frames,
)