Commit 69467ea5 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Refactor LibriSpeech Conformer RNN-T recipe (#2366)

Summary:
Modifies the example LibriSpeech Conformer RNN-T recipe as follows:
- Moves data loading and transforms logic from lightning module to data module (improves generalizability and reusability of lightning module and data module).
- Moves transforms logic from dataloader collator function to dataset (resolves dataloader multiprocessing issues on certain platforms).
- Replaces lambda functions with `partial` equivalents (resolves pickling issues in certain runtime environments).
- Modifies the training script to allow specifying the path to a model checkpoint from which to restart training.

Pull Request resolved: https://github.com/pytorch/audio/pull/2366

Reviewed By: mthrok

Differential Revision: D36305028

Pulled By: hwangjeff

fbshipit-source-id: 0b768da5d5909136c55418bf0a3c2ddd0c5683ba
parent 93c26d63
import os
import random
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
def _batch_by_token_count(idx_target_lengths, max_tokens, batch_size=None):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > max_tokens or (batch_size and len(current_batch) == batch_size):
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
def get_sample_lengths(librispeech_dataset):
    """Return the transcript character length for every sample in the dataset.

    Transcript files are parsed lazily and cached, so each ``*.trans`` text
    file is read at most once even though it holds many utterances.
    """
    length_cache = {}

    def _lookup(fileid):
        if fileid not in length_cache:
            speaker_id, chapter_id, _ = fileid.split("-")
            transcript_name = f"{speaker_id}-{chapter_id}{librispeech_dataset._ext_txt}"
            transcript_path = os.path.join(
                librispeech_dataset._path, speaker_id, chapter_id, transcript_name
            )
            # Cache the lengths of every utterance in the file in one pass.
            with open(transcript_path) as f:
                for line in f:
                    utterance_id, transcript = line.strip().split(" ", 1)
                    length_cache[utterance_id] = len(transcript)
        return length_cache[fileid]

    return [_lookup(fileid) for fileid in librispeech_dataset._walker]
class CustomBucketDataset(torch.utils.data.Dataset):
    """Dataset of pre-assembled batches, grouped by length bucket.

    Samples are assigned to ``num_buckets`` equal-width length buckets; within
    the bucket ordering, consecutive indices are packed greedily into batches
    whose summed lengths stay within ``max_tokens`` (and, optionally, whose
    sample count stays within ``batch_size``). ``__getitem__`` returns the
    list of raw samples making up one batch.
    """

    def __init__(
        self,
        dataset,
        lengths,
        max_tokens,
        num_buckets,
        shuffle=False,
        batch_size=None,
    ):
        super().__init__()
        assert len(dataset) == len(lengths)
        self.dataset = dataset

        shortest, longest = min(lengths), max(lengths)
        assert max_tokens >= longest
        boundaries = torch.linspace(shortest, longest, num_buckets)
        length_tensor = torch.tensor(lengths)
        assignments = torch.bucketize(length_tensor, boundaries)
        entries = [(i, length, assignments[i]) for i, length in enumerate(length_tensor)]
        if shuffle:
            entries = random.sample(entries, len(entries))
        else:
            # Longest-first within each bucket when not shuffling.
            entries.sort(key=lambda e: e[1], reverse=True)
        entries.sort(key=lambda e: e[2])  # stable: preserves within-bucket order

        # Pack bucket-ordered indices greedily into batches (inlined from the
        # module's _batch_by_token_count helper; behavior is identical).
        self.batches = []
        batch, token_count = [], 0
        for idx, length, _ in entries:
            if token_count + length > max_tokens or (batch_size and len(batch) == batch_size):
                self.batches.append(batch)
                batch, token_count = [idx], length
            else:
                batch.append(idx)
                token_count += length
        if batch:
            self.batches.append(batch)

    def __getitem__(self, idx):
        return [self.dataset[i] for i in self.batches[idx]]

    def __len__(self):
        return len(self.batches)
class TransformDataset(torch.utils.data.Dataset):
    """Applies ``transform_fn`` lazily to each item of a wrapped dataset.

    Running transforms inside the dataset (rather than the DataLoader collate
    function) keeps them compatible with worker multiprocessing.
    """

    def __init__(self, dataset, transform_fn):
        self.dataset = dataset
        self.transform_fn = transform_fn

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.transform_fn(self.dataset[idx])
class LibriSpeechDataModule(LightningDataModule):
    """Lightning data module wrapping the LibriSpeech train/dev/test splits.

    Batches are pre-assembled by ``CustomBucketDataset``, so all DataLoaders
    run with ``batch_size=None`` — each dataset item is already a full batch.
    Per-split transforms are applied inside ``TransformDataset`` so that they
    work with DataLoader worker processes.
    """

    def __init__(
        self,
        *,
        librispeech_path,
        train_transform,
        val_transform,
        test_transform,
        max_tokens=700,
        batch_size=2,
        train_num_buckets=50,
        train_shuffle=True,
        num_workers=10,
    ):
        # Bug fix: LightningDataModule requires its base initializer to run;
        # the original skipped it, leaving the module's internal hook state
        # unconfigured.
        super().__init__()
        self.librispeech_path = librispeech_path
        # Per-dataset transcript lengths, computed once on first dataloader
        # creation and cached across epochs (dataloaders are rebuilt when the
        # trainer uses reload_dataloaders_every_n_epochs).
        self.train_dataset_lengths = None
        self.val_dataset_lengths = None
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform
        self.max_tokens = max_tokens
        self.batch_size = batch_size
        self.train_num_buckets = train_num_buckets
        self.train_shuffle = train_shuffle
        self.num_workers = num_workers

    def train_dataloader(self):
        """Return a loader over the three LibriSpeech training subsets."""
        datasets = [
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-360"),
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-100"),
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-other-500"),
        ]
        if not self.train_dataset_lengths:
            self.train_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
        dataset = torch.utils.data.ConcatDataset(
            [
                CustomBucketDataset(
                    dataset,
                    lengths,
                    self.max_tokens,
                    self.train_num_buckets,
                    batch_size=self.batch_size,
                )
                for dataset, lengths in zip(datasets, self.train_dataset_lengths)
            ]
        )
        dataset = TransformDataset(dataset, self.train_transform)
        # batch_size=None: items are already batches; shuffle reorders batches.
        return torch.utils.data.DataLoader(
            dataset,
            num_workers=self.num_workers,
            batch_size=None,
            shuffle=self.train_shuffle,
        )

    def val_dataloader(self):
        """Return a loader over the dev-clean and dev-other subsets."""
        datasets = [
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-clean"),
            torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-other"),
        ]
        if not self.val_dataset_lengths:
            self.val_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
        dataset = torch.utils.data.ConcatDataset(
            [
                # A single bucket: validation needs no length-bucketed shuffling.
                CustomBucketDataset(
                    dataset,
                    lengths,
                    self.max_tokens,
                    1,
                    batch_size=self.batch_size,
                )
                for dataset, lengths in zip(datasets, self.val_dataset_lengths)
            ]
        )
        dataset = TransformDataset(dataset, self.val_transform)
        return torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return a loader over test-clean; one raw sample per item."""
        dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean")
        dataset = TransformDataset(dataset, self.test_transform)
        return torch.utils.data.DataLoader(dataset, batch_size=None)
...@@ -4,7 +4,7 @@ from argparse import ArgumentParser ...@@ -4,7 +4,7 @@ from argparse import ArgumentParser
import torch import torch
import torchaudio import torchaudio
from lightning import ConformerRNNTModule from lightning import ConformerRNNTModule, get_data_module
logger = logging.getLogger() logger = logging.getLogger()
...@@ -15,19 +15,15 @@ def compute_word_level_distance(seq1, seq2): ...@@ -15,19 +15,15 @@ def compute_word_level_distance(seq1, seq2):
def run_eval(args): def run_eval(args):
model = ConformerRNNTModule.load_from_checkpoint( model = ConformerRNNTModule.load_from_checkpoint(args.checkpoint_path, sp_model_path=str(args.sp_model_path)).eval()
args.checkpoint_path, data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path))
librispeech_path=str(args.librispeech_path),
sp_model_path=str(args.sp_model_path),
global_stats_path=str(args.global_stats_path),
).eval()
if args.use_cuda: if args.use_cuda:
model = model.to(device="cuda") model = model.to(device="cuda")
total_edit_distance = 0 total_edit_distance = 0
total_length = 0 total_length = 0
dataloader = model.test_dataloader() dataloader = data_module.test_dataloader()
with torch.no_grad(): with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader): for idx, (batch, sample) in enumerate(dataloader):
actual = sample[0][2] actual = sample[0][2]
...@@ -45,6 +41,7 @@ def cli_main(): ...@@ -45,6 +41,7 @@ def cli_main():
"--checkpoint-path", "--checkpoint-path",
type=pathlib.Path, type=pathlib.Path,
help="Path to checkpoint to use for evaluation.", help="Path to checkpoint to use for evaluation.",
required=True,
) )
parser.add_argument( parser.add_argument(
"--global-stats-path", "--global-stats-path",
...@@ -56,11 +53,13 @@ def cli_main(): ...@@ -56,11 +53,13 @@ def cli_main():
"--librispeech-path", "--librispeech-path",
type=pathlib.Path, type=pathlib.Path,
help="Path to LibriSpeech datasets.", help="Path to LibriSpeech datasets.",
required=True,
) )
parser.add_argument( parser.add_argument(
"--sp-model-path", "--sp-model-path",
type=pathlib.Path, type=pathlib.Path,
help="Path to SentencePiece model.", help="Path to SentencePiece model.",
required=True,
) )
parser.add_argument( parser.add_argument(
"--use-cuda", "--use-cuda",
......
import json
import logging import logging
import math import math
import os
import random
from collections import namedtuple
from typing import List, Tuple from typing import List, Tuple
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from pytorch_lightning import LightningModule, seed_everything from data_module import LibriSpeechDataModule
from pytorch_lightning import LightningModule
from torchaudio.models import Hypothesis, RNNTBeamSearch from torchaudio.models import Hypothesis, RNNTBeamSearch
from torchaudio.prototype.models import conformer_rnnt_base from torchaudio.prototype.models import conformer_rnnt_base
from transforms import Batch, TrainTransform, ValTransform, TestTransform
logger = logging.getLogger() logger = logging.getLogger()
seed_everything(1)
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)
_expected_spm_vocab_size = 1023 _expected_spm_vocab_size = 1023
def _piecewise_linear_log(x):
x[x > math.e] = torch.log(x[x > math.e])
x[x <= math.e] = x[x <= math.e] / math.e
return x
def _batch_by_token_count(idx_target_lengths, token_limit, sample_limit=None):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > token_limit or (sample_limit and len(current_batch) == sample_limit):
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
def get_sample_lengths(librispeech_dataset):
fileid_to_target_length = {}
def _target_length(fileid):
if fileid not in fileid_to_target_length:
speaker_id, chapter_id, _ = fileid.split("-")
file_text = speaker_id + "-" + chapter_id + librispeech_dataset._ext_txt
file_text = os.path.join(librispeech_dataset._path, speaker_id, chapter_id, file_text)
with open(file_text) as ft:
for line in ft:
fileid_text, transcript = line.strip().split(" ", 1)
fileid_to_target_length[fileid_text] = len(transcript)
return fileid_to_target_length[fileid]
return [_target_length(fileid) for fileid in librispeech_dataset._walker]
class CustomBucketDataset(torch.utils.data.Dataset):
def __init__(self, dataset, lengths, max_token_limit, num_buckets, shuffle=False, sample_limit=None):
super().__init__()
assert len(dataset) == len(lengths)
self.dataset = dataset
max_length = max(lengths)
min_length = min(lengths)
assert max_token_limit >= max_length
buckets = torch.linspace(min_length, max_length, num_buckets)
lengths = torch.tensor(lengths)
bucket_assignments = torch.bucketize(lengths, buckets)
idx_length_buckets = [(idx, length, bucket_assignments[idx]) for idx, length in enumerate(lengths)]
if shuffle:
idx_length_buckets = random.sample(idx_length_buckets, len(idx_length_buckets))
else:
idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[1], reverse=True)
sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2])
self.batches = _batch_by_token_count(
[(idx, length) for idx, length, _ in sorted_idx_length_buckets], max_token_limit, sample_limit=sample_limit
)
def __getitem__(self, idx):
return [self.dataset[subidx] for subidx in self.batches[idx]]
def __len__(self):
return len(self.batches)
class FunctionalModule(torch.nn.Module):
def __init__(self, functional):
super().__init__()
self.functional = functional
def forward(self, input):
return self.functional(input)
class GlobalStatsNormalization(torch.nn.Module):
def __init__(self, global_stats_path):
super().__init__()
with open(global_stats_path) as f:
blob = json.loads(f.read())
self.mean = torch.tensor(blob["mean"])
self.invstddev = torch.tensor(blob["invstddev"])
def forward(self, input):
return (input - self.mean) * self.invstddev
class WarmupLR(torch.optim.lr_scheduler._LRScheduler): class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
r"""Learning rate scheduler that performs linear warmup and exponential annealing. r"""Learning rate scheduler that performs linear warmup and exponential annealing.
...@@ -189,13 +74,7 @@ def post_process_hypos( ...@@ -189,13 +74,7 @@ def post_process_hypos(
class ConformerRNNTModule(LightningModule): class ConformerRNNTModule(LightningModule):
def __init__( def __init__(self, sp_model_path):
self,
*,
librispeech_path: str,
sp_model_path: str,
global_stats_path: str,
):
super().__init__() super().__init__()
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path) self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
...@@ -214,65 +93,8 @@ class ConformerRNNTModule(LightningModule): ...@@ -214,65 +93,8 @@ class ConformerRNNTModule(LightningModule):
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=8e-4, betas=(0.9, 0.98), eps=1e-9) self.optimizer = torch.optim.Adam(self.model.parameters(), lr=8e-4, betas=(0.9, 0.98), eps=1e-9)
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 40, 120, 0.96) self.warmup_lr_scheduler = WarmupLR(self.optimizer, 40, 120, 0.96)
self.train_data_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(lambda x: x.transpose(1, 2)),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.TimeMasking(100, p=0.2),
torchaudio.transforms.TimeMasking(100, p=0.2),
FunctionalModule(lambda x: x.transpose(1, 2)),
)
self.valid_data_pipeline = torch.nn.Sequential(
FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
GlobalStatsNormalization(global_stats_path),
)
self.librispeech_path = librispeech_path
self.train_dataset_lengths = None
self.val_dataset_lengths = None
self.automatic_optimization = False self.automatic_optimization = False
def _extract_labels(self, samples: List):
targets = [self.sp_model.encode(sample[2].lower()) for sample in samples]
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths
def _train_extract_features(self, samples: List):
mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.train_data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _valid_extract_features(self, samples: List):
mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.valid_data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _train_collate_fn(self, samples: List):
features, feature_lengths = self._train_extract_features(samples)
targets, target_lengths = self._extract_labels(samples)
return Batch(features, feature_lengths, targets, target_lengths)
def _valid_collate_fn(self, samples: List):
features, feature_lengths = self._valid_extract_features(samples)
targets, target_lengths = self._extract_labels(samples)
return Batch(features, feature_lengths, targets, target_lengths)
def _test_collate_fn(self, samples: List):
return self._valid_collate_fn(samples), samples
def _step(self, batch, _, step_type): def _step(self, batch, _, step_type):
if batch is None: if batch is None:
return None return None
...@@ -348,55 +170,14 @@ class ConformerRNNTModule(LightningModule): ...@@ -348,55 +170,14 @@ class ConformerRNNTModule(LightningModule):
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test") return self._step(batch, batch_idx, "test")
def train_dataloader(self):
datasets = [
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-360"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-100"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-other-500"),
]
if not self.train_dataset_lengths:
self.train_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(dataset, lengths, 700, 50, shuffle=False, sample_limit=2)
for dataset, lengths in zip(datasets, self.train_dataset_lengths)
]
)
dataloader = torch.utils.data.DataLoader(
dataset,
collate_fn=self._train_collate_fn,
num_workers=10,
batch_size=None,
shuffle=True,
)
return dataloader
def val_dataloader(self):
datasets = [
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-clean"),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-other"),
]
if not self.val_dataset_lengths:
self.val_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(dataset, lengths, 700, 1, sample_limit=2)
for dataset, lengths in zip(datasets, self.val_dataset_lengths)
]
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=None,
collate_fn=self._valid_collate_fn,
num_workers=10,
)
return dataloader
def test_dataloader(self): def get_data_module(librispeech_path, global_stats_path, sp_model_path):
dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean") train_transform = TrainTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn) val_transform = ValTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
return dataloader test_transform = TestTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
return LibriSpeechDataModule(
librispeech_path=librispeech_path,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
)
import pathlib import pathlib
from argparse import ArgumentParser from argparse import ArgumentParser
from lightning import ConformerRNNTModule from lightning import ConformerRNNTModule, get_data_module
from pytorch_lightning import Trainer from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.plugins import DDPPlugin from pytorch_lightning.plugins import DDPPlugin
def run_train(args): def run_train(args):
seed_everything(1)
checkpoint_dir = args.exp_dir / "checkpoints" checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint( checkpoint = ModelCheckpoint(
checkpoint_dir, checkpoint_dir,
...@@ -42,16 +43,19 @@ def run_train(args): ...@@ -42,16 +43,19 @@ def run_train(args):
reload_dataloaders_every_n_epochs=1, reload_dataloaders_every_n_epochs=1,
) )
model = ConformerRNNTModule( model = ConformerRNNTModule(str(args.sp_model_path))
librispeech_path=str(args.librispeech_path), data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path))
sp_model_path=str(args.sp_model_path), trainer.fit(model, data_module, ckpt_path=args.checkpoint_path)
global_stats_path=str(args.global_stats_path),
)
trainer.fit(model)
def cli_main(): def cli_main():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
default=None,
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
default=pathlib.Path("./exp"), default=pathlib.Path("./exp"),
...@@ -68,11 +72,13 @@ def cli_main(): ...@@ -68,11 +72,13 @@ def cli_main():
"--librispeech-path", "--librispeech-path",
type=pathlib.Path, type=pathlib.Path,
help="Path to LibriSpeech datasets.", help="Path to LibriSpeech datasets.",
required=True,
) )
parser.add_argument( parser.add_argument(
"--sp-model-path", "--sp-model-path",
type=pathlib.Path, type=pathlib.Path,
help="Path to SentencePiece model.", help="Path to SentencePiece model.",
required=True,
) )
parser.add_argument( parser.add_argument(
"--nodes", "--nodes",
...@@ -93,7 +99,6 @@ def cli_main(): ...@@ -93,7 +99,6 @@ def cli_main():
help="Number of epochs to train for. (Default: 120)", help="Number of epochs to train for. (Default: 120)",
) )
args = parser.parse_args() args = parser.parse_args()
run_train(args) run_train(args)
......
import json
import math
from collections import namedtuple
from functools import partial
from typing import List
import sentencepiece as spm
import torch
import torchaudio
# A collated batch: padded features/targets plus per-sample lengths.
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
# Scale factor derived from the full-scale int16 amplitude; applied to
# waveform-derived features before the piecewise-log compression.
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
# 80-band mel spectrogram front end at 16 kHz (n_fft=400 -> 25 ms window,
# hop_length=160 -> 10 ms hop).
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)
def _piecewise_linear_log(x):
    """Scale by the int16 gain, then compress: natural log above e, x/e at or below.

    NOTE(review): the two masked writes run sequentially on the same tensor, so
    inputs whose log lands at or below e (i.e. original values in (e, e**e])
    are first logged and then also divided by e. Kept as-is to preserve the
    original numerics that the trained global stats were computed against.
    """
    x = x * _gain
    big = x > math.e
    x[big] = torch.log(x[big])
    small = x <= math.e
    x[small] = x[small] / math.e
    return x
class FunctionalModule(torch.nn.Module):
    """Adapts an arbitrary callable into an ``nn.Module``.

    Lets plain functions (or ``functools.partial`` objects) participate in an
    ``nn.Sequential`` pipeline.
    """

    def __init__(self, functional):
        super().__init__()
        self.functional = functional

    def forward(self, input):
        # Delegate directly to the wrapped callable.
        return self.functional(input)
class GlobalStatsNormalization(torch.nn.Module):
    """Normalizes features with precomputed dataset-wide statistics.

    The JSON file at ``global_stats_path`` must contain "mean" and
    "invstddev" arrays broadcastable against the input features.
    """

    def __init__(self, global_stats_path):
        super().__init__()
        with open(global_stats_path) as f:
            stats = json.load(f)
        self.mean = torch.tensor(stats["mean"])
        self.invstddev = torch.tensor(stats["invstddev"])

    def forward(self, input):
        # Standardize: subtract mean, multiply by inverse standard deviation.
        return (input - self.mean) * self.invstddev
def _extract_labels(sp_model, samples: List):
targets = [sp_model.encode(sample[2].lower()) for sample in samples]
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths
def _extract_features(data_pipeline, samples: List):
    """Compute padded mel-spectrogram features and per-sample frame counts.

    Each sample's waveform (sample index 0) is converted to a (time, mel)
    spectrogram; sequences are right-padded to a common length and then fed
    through ``data_pipeline``. Lengths are the unpadded frame counts.
    """
    specs = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
    frame_counts = torch.tensor([spec.shape[0] for spec in specs], dtype=torch.int32)
    padded = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True)
    return data_pipeline(padded), frame_counts
class TrainTransform:
    """Collates raw LibriSpeech samples into a training ``Batch``.

    Features go through gain + piecewise-log compression, global-stats
    normalization, and SpecAugment-style frequency/time masking; transcripts
    are SentencePiece-encoded.
    """

    def __init__(self, global_stats_path: str, sp_model_path: str):
        self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
        # Masking transforms operate on (batch, freq, time), hence the
        # transposes around them.
        stages = [
            FunctionalModule(_piecewise_linear_log),
            GlobalStatsNormalization(global_stats_path),
            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
            torchaudio.transforms.FrequencyMasking(27),
            torchaudio.transforms.FrequencyMasking(27),
            torchaudio.transforms.TimeMasking(100, p=0.2),
            torchaudio.transforms.TimeMasking(100, p=0.2),
            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
        ]
        self.train_data_pipeline = torch.nn.Sequential(*stages)

    def __call__(self, samples: List):
        feats, feat_lens = _extract_features(self.train_data_pipeline, samples)
        tgts, tgt_lens = _extract_labels(self.sp_model, samples)
        return Batch(feats, feat_lens, tgts, tgt_lens)
class ValTransform:
    """Collates raw LibriSpeech samples into a validation ``Batch``.

    Same front end as ``TrainTransform`` but without the SpecAugment masking.
    """

    def __init__(self, global_stats_path: str, sp_model_path: str):
        self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
        self.valid_data_pipeline = torch.nn.Sequential(
            FunctionalModule(_piecewise_linear_log),
            GlobalStatsNormalization(global_stats_path),
        )

    def __call__(self, samples: List):
        feats, feat_lens = _extract_features(self.valid_data_pipeline, samples)
        tgts, tgt_lens = _extract_labels(self.sp_model, samples)
        return Batch(feats, feat_lens, tgts, tgt_lens)
class TestTransform:
    """Collates one evaluation sample: returns (validation Batch, [raw sample]).

    Evaluation needs the raw sample alongside the features so the reference
    transcript can be compared against the decoded hypothesis.
    """

    def __init__(self, global_stats_path: str, sp_model_path: str):
        # Reuse the validation front end (no masking) for test-time features.
        self.val_transforms = ValTransform(global_stats_path, sp_model_path)

    def __call__(self, sample):
        return self.val_transforms([sample]), [sample]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.