"...text-generation-inference.git" did not exist on "935a77fb74d7e5e31579b7bf3c9263c23d6dd17b"
Unverified commit 9c274228, authored by Vincent QB and committed by GitHub.

Example pipeline with wav2letter (#632)

* example pipeline, initial commit.

* removing notebook conversion artifacts.

* remove extra comments. lint.

* addressing some feedback.

* main function.

* defining args in function.

* refactor.

* lint.

* checkpoint.

* clean version to start with.

* adding more parameters.

* lint.

* cleaning full version.

* check for not None.

* cleaning.

* black -l 160.

* black.

* fix runtime error.

* removing some print statements.

* add help to command line. add progress bar option.

* grouping librispeech-specific transform in subclass.

* typo.

* fix concatenation.

* typo.

* black. tqdm.

* missing transpose.

* renaming variables.

* sum cer and wer

* clip norm.

* second signal handler removed.

* cosmetic.

* default to no checkpoint.

* remove non_blocking.

* adadelta works better than sgd.

* anomaly detection.

* moving dataset to separate file.

* lint.

* move to separate module: languagemodel, decoder, metric.

* flush=True.

* renaming decoder.

* CTC Decoders.

* flush=True.

* pass length for viterbi decoder.

* progress bar. relative path.

* generalize transition matrix to n-gram. progress bar.

* choice of decoder.

* collate func.

* remove signal handling.

* adding distributed.

* lint.

* normalize w.r.t. length of dataset, and w.r.t. total number of characters.

* relative cer/wer.

* clip grad parameter. momentum back but not yet used.

* Switch to SGD.

* choice of optimizer.

* scheduler.

* move to utils file.

* metric log, and utils file.

* rename metric_logger.

* stderr and stdout. simpler metric logger.

* replace by logging.

* adding time measurement in metric logger.

* fix duplicate name. remove tqdm. keep track of epoch and iteration instead.

* rename main file. and add readme.

* refactor distributed.

* swap example and output in readme.

* remove time from logger.

* check non-empty tensor input.

* typo in variable name and log update.

* typo.

* compute cer/wer in training too.

* typo.

* add back slurm signal capture to resubmit job.

* update levenshtein distance.

* adding tests for levenshtein distance.

* record error rate during iteration.

* metric logger using setitem.

* moving signal break to end of loop and return loss so far.

* typo.

* add citation.

* change default to best run.

* adding other experiment with decoders.

* remove other decoders than greedy.

* Revert "remove other decoders than greedy."

This reverts commit fb114372e89e317bf48d0b1f846c60bca8efe1ac.

* changing name of folder.

* remove other decoders, and unused dataset class.

* rename functions to align with other pipeline.

* pick which parts to train with.

* adding specaugment to validation. note that caching prevents randomization from happening in validation.

* updating readme.

* typo in metric logging.

* Revert "typo in metric logging."

This reverts commit acac245eec250f61d2039a67933d3c01f1975ce9.

* Revert "Revert "typo in metric logging.""

This reverts commit 2c80d9691ed401044da734c40df3715dba92d0db.

* update metric logger.

* simplify metric logger implementation.

* use json dumps instead.

* group metric together.

* move function.

* lint.

* quick summary of files in folder.

* pass clip_grad explicitly.

* typo in default dataset name.

* option to disable logger.

* ergonomics for distributed.

* reminder about signal handler.

* minor refactor of main in main.

* replace by not_main_rank.

* raising error if parameter not supported.

* move model before invoking DDP.

* changing log level. using python 2 style string for logging.

* dynamic augmentations.

* update metric log.

batch cer/wer metric. correct typo in time. adding other dimensions in metric.

* save learning rate even if function not available.

* add type option to model.

* add adamw.

* reduce lr on validation step or training step.

* specify hop-length and win-length.

* normalize option.

* rename parameter.

* add dropout and tweak to number of channels.

* copy model in pipeline folder for experimentation.

* fix scheduler stepping.

* fix input_type and num_features.

* waveform mode changes shape more.

* adding best character error rate with current implementation of model with mfcc.

* comment update.

* remove signal. remove custom wav2letter model.

* remove comment.

* simpler import with pandas.
This is an example pipeline for speech recognition using a greedy CTC decoder and the Wav2Letter model trained on LibriSpeech; see [Wav2Letter: an End-to-End ConvNet-based Speech Recognition System](https://arxiv.org/pdf/1609.03193.pdf). Both Wav2Letter and LibriSpeech are available in torchaudio, as sketched below.
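As a quick, self-contained sketch (not part of this pipeline's code), both can be loaded directly from torchaudio; `num_classes=29` matches the label set used here (blank, space, apostrophe, a-z), and the dataset root below is a placeholder:
```python
import torchaudio

# Wav2Letter acoustic model: 13 MFCC features in, 29 output classes
# (blank + space + apostrophe + 26 letters, as used in this pipeline).
model = torchaudio.models.Wav2Letter(num_classes=29, input_type="mfcc", num_features=13)

# LibriSpeech dataset; "./librispeech_root" is a placeholder directory.
dataset = torchaudio.datasets.LIBRISPEECH("./librispeech_root", "dev-clean", download=True)
waveform, sample_rate, transcript, *_ = dataset[0]
```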
### Usage
More information about each command line parameter is available with the `--help` option. An example invocation is shown below.
```
python main.py \
--reduce-lr-valid \
--dataset-train train-clean-100 train-clean-360 train-other-500 \
--dataset-valid dev-clean \
--batch-size 128 \
--learning-rate .6 \
--momentum .8 \
--weight-decay .00001 \
--clip-grad 0. \
--gamma .99 \
--hop-length 160 \
--n-hidden-channels 2000 \
--win-length 400 \
--n-bins 13 \
--normalize \
--optimizer adadelta \
--scheduler reduceonplateau \
--epochs 30
```
With the parameters above, we get a character error rate of 13.8% on dev-clean after 30 epochs.
### Output
The information reported at each iteration and epoch (e.g. loss, character error rate, word error rate) is printed to standard output as one JSON object per line, e.g.
```python
{"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 23317.0, "total chars": 23317.0, "cer": 0.0, "cumulative cer over target length": 0.0, "wer over target length": 1.0, "cumulative wer": 4446.0, "total words": 4446.0, "wer": 0.0, "cumulative wer over target length": 0.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 2453, "dataset length": 128.0, "iteration": 1.0, "loss": 8.712121963500977, "cumulative loss": 8.712121963500977, "average loss": 8.712121963500977, "iteration time": 41.46276903152466, "epoch time": 41.46276903152466}
{"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 46005.0, "total chars": 46005.0, "cer": 0.0, "cumulative cer over target length": 0.0, "wer over target length": 1.0, "cumulative wer": 8762.0, "total words": 8762.0, "wer": 0.0, "cumulative wer over target length": 0.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 1703, "dataset length": 256.0, "iteration": 2.0, "loss": 8.918599128723145, "cumulative loss": 17.63072109222412, "average loss": 8.81536054611206, "iteration time": 1.2905676364898682, "epoch time": 42.753336668014526}
{"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 70030.0, "total chars": 70030.0, "cer": 0.0, "cumulative cer over target length": 0.0, "wer over target length": 1.0, "cumulative wer": 13348.0, "total words": 13348.0, "wer": 0.0, "cumulative wer over target length": 0.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 1713, "dataset length": 384.0, "iteration": 3.0, "loss": 8.550191879272461, "cumulative loss": 26.180912971496582, "average loss": 8.726970990498861, "iteration time": 1.2109291553497314, "epoch time": 43.96426582336426}
```
One way to import the output into Python with pandas is to save the standard output to a file and then use `pandas.read_json(filename, lines=True)`.
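For instance, assuming the log was redirected to a hypothetical file named `train_log.json`:
```python
import pandas as pd

# python main.py ... > train_log.json
log = pd.read_json("train_log.json", lines=True)

# Keep the training rows and look at a few of the reported columns.
train = log[log["name"] == "train"]
print(train[["epoch", "iteration", "average loss", "cer over target length"]].tail())
```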
### Structure of pipeline
* `main.py` -- the entry point
* `ctc_decoders.py` -- the greedy CTC decoder (see the sketch after this list for how it composes with the language model)
* `datasets.py` -- the functions to split and process LibriSpeech, and a collate factory function
* `languagemodels.py` -- a class to encode and decode strings
* `metrics.py` -- the Levenshtein edit distance
* `transforms.py` -- additional input transforms (`Normalize`, `UnsqueezeFirst`)
* `utils.py` -- functions to log metrics, save checkpoints, and count parameters
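As a minimal sketch of how `ctc_decoders.py` and `languagemodels.py` compose (mirroring `compute_error_rates` in `main.py`), assuming the files above are importable; the random tensor stands in for real network output, so the decoded strings are gibberish:
```python
import string

import torch

from ctc_decoders import GreedyDecoder
from languagemodels import LanguageModel

# Label set used throughout the pipeline: blank, space, apostrophe, a-z.
char_blank, char_space = "*", " "
labels = char_blank + char_space + "'" + string.ascii_lowercase
language_model = LanguageModel(labels, char_blank, char_space)
decoder = GreedyDecoder()

# Stand-in for network output of shape (input length, batch size, num classes).
outputs = torch.randn(50, 2, len(language_model))

# Batch first, greedy argmax, then map indices back to characters,
# collapsing repeats and removing blanks.
indices = decoder(outputs.transpose(0, 1).to("cpu"))
transcripts = language_model.decode(indices.tolist())
print(transcripts)
```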
from torch import topk
class GreedyDecoder:
def __call__(self, outputs):
"""Greedy Decoder. Returns highest probability of class labels for each timestep
Args:
outputs (torch.Tensor): shape (input length, batch size, number of classes (including blank))
Returns:
torch.Tensor: class labels per time step.
"""
_, indices = topk(outputs, k=1, dim=-1)
return indices[..., 0]
import torch
from torchaudio.datasets import LIBRISPEECH
class MapMemoryCache(torch.utils.data.Dataset):
"""
Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
"""
def __init__(self, dataset):
self.dataset = dataset
self._cache = [None] * len(dataset)
def __getitem__(self, n):
if self._cache[n] is not None:
return self._cache[n]
item = self.dataset[n]
self._cache[n] = item
return item
def __len__(self):
return len(self.dataset)
class Processed(torch.utils.data.Dataset):
def __init__(self, dataset, transforms, encode):
self.dataset = dataset
self.transforms = transforms
self.encode = encode
def __getitem__(self, key):
item = self.dataset[key]
return self.process_datapoint(item)
def __len__(self):
return len(self.dataset)
def process_datapoint(self, item):
transformed = item[0]
target = item[2].lower()
transformed = self.transforms(transformed)
transformed = transformed[0, ...].transpose(0, -1)
target = self.encode(target)
target = torch.tensor(target, dtype=torch.long, device=transformed.device)
return transformed, target
def split_process_librispeech(
datasets, transforms, language_model, root, folder_in_archive,
):
    if not isinstance(transforms, list):
        transforms = [transforms] * len(datasets)
    def create(tags, transform, cache=True):
        if isinstance(tags, str):
            tags = [tags]
        data = torch.utils.data.ConcatDataset(
            [
                Processed(
                    LIBRISPEECH(
                        root, tag, folder_in_archive=folder_in_archive, download=False,
                    ),
                    transform,
                    language_model.encode,
                )
                for tag in tags
            ]
        )
        data = MapMemoryCache(data)
        return data
    # For performance, we cache all datasets. Each split gets a single transform shared
    # by all of its LibriSpeech parts (zipping tags against transforms would silently
    # drop parts when a split has more tags than there are transforms).
    return tuple(create(tags, transform) for tags, transform in zip(datasets, transforms))
def collate_factory(model_length_function, transforms=None):
if transforms is None:
transforms = torch.nn.Sequential()
def collate_fn(batch):
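        # Apply the optional per-item transforms (data augmentations during training),
        # compute each item's expected model output length, then pad to a common length;
        # after the final transpose the batch has shape (batch, features, time).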
tensors = [transforms(b[0]) for b in batch if b]
tensors_lengths = torch.tensor(
[model_length_function(t) for t in tensors],
dtype=torch.long,
device=tensors[0].device,
)
tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
tensors = tensors.transpose(1, -1)
targets = [b[1] for b in batch if b]
target_lengths = torch.tensor(
[target.shape[0] for target in targets],
dtype=torch.long,
device=tensors.device,
)
targets = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)
return tensors, targets, tensors_lengths, target_lengths
return collate_fn
import collections
import itertools
class LanguageModel:
def __init__(self, labels, char_blank, char_space):
self.char_space = char_space
self.char_blank = char_blank
        labels = list(labels)
self.length = len(labels)
enumerated = list(enumerate(labels))
flipped = [(sub[1], sub[0]) for sub in enumerated]
d1 = collections.OrderedDict(enumerated)
d2 = collections.OrderedDict(flipped)
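        # self.mapping goes both ways: index -> character and character -> index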
self.mapping = {**d1, **d2}
def encode(self, iterable):
if isinstance(iterable, list):
return [self.encode(i) for i in iterable]
else:
return [self.mapping[i] + self.mapping[self.char_blank] for i in iterable]
def decode(self, tensor):
if len(tensor) > 0 and isinstance(tensor[0], list):
return [self.decode(t) for t in tensor]
else:
            # not idempotent: repeated characters are collapsed and blanks removed
x = (self.mapping[i] for i in tensor)
x = "".join(i for i, _ in itertools.groupby(x))
x = x.replace(self.char_blank, "")
# x = x.strip()
return x
def __len__(self):
return self.length
import argparse
import logging
import os
import string
from datetime import datetime
from time import time
import torch
import torchaudio
from torch.optim import SGD, Adadelta, Adam, AdamW
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.models.wav2letter import Wav2Letter
from ctc_decoders import GreedyDecoder
from datasets import collate_factory, split_process_librispeech
from languagemodels import LanguageModel
from metrics import levenshtein_distance
from transforms import Normalize, UnsqueezeFirst
from utils import MetricLogger, count_parameters, save_checkpoint
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--type",
metavar="T",
default="mfcc",
choices=["waveform", "mfcc"],
help="input type for model",
)
parser.add_argument(
"--freq-mask",
default=0,
type=int,
metavar="N",
help="maximal width of frequency mask",
)
parser.add_argument(
"--win-length",
default=400,
type=int,
metavar="N",
help="width of spectrogram window",
)
parser.add_argument(
"--hop-length",
default=160,
type=int,
metavar="N",
help="width of spectrogram window",
)
parser.add_argument(
"--time-mask",
default=0,
type=int,
metavar="N",
help="maximal width of time mask",
)
parser.add_argument(
"--workers",
default=0,
type=int,
metavar="N",
help="number of data loading workers",
)
parser.add_argument(
"--checkpoint",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint",
)
parser.add_argument(
"--epochs",
default=200,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch", default=0, type=int, metavar="N", help="manual epoch number"
)
parser.add_argument(
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency in epochs",
)
parser.add_argument(
"--reduce-lr-valid",
action="store_true",
help="reduce learning rate based on validation loss",
)
parser.add_argument(
"--normalize", action="store_true", help="normalize model input"
)
parser.add_argument(
"--progress-bar", action="store_true", help="use progress bar while training"
)
parser.add_argument(
"--decoder",
metavar="D",
default="greedy",
choices=["greedy"],
help="decoder to use",
)
parser.add_argument(
"--batch-size", default=128, type=int, metavar="N", help="mini-batch size"
)
parser.add_argument(
"--n-bins",
default=13,
type=int,
metavar="N",
help="number of bins in transforms",
)
parser.add_argument(
"--optimizer",
metavar="OPT",
default="adadelta",
choices=["sgd", "adadelta", "adam", "adamw"],
help="optimizer to use",
)
parser.add_argument(
"--scheduler",
metavar="S",
default="reduceonplateau",
choices=["exponential", "reduceonplateau"],
help="optimizer to use",
)
parser.add_argument(
"--learning-rate",
default=0.6,
type=float,
metavar="LR",
help="initial learning rate",
)
parser.add_argument(
"--gamma",
default=0.99,
type=float,
metavar="GAMMA",
help="learning rate exponential decay constant",
)
parser.add_argument(
"--momentum", default=0.8, type=float, metavar="M", help="momentum"
)
parser.add_argument(
"--weight-decay", default=1e-5, type=float, metavar="W", help="weight decay"
)
parser.add_argument("--eps", metavar="EPS", type=float, default=1e-8)
parser.add_argument("--rho", metavar="RHO", type=float, default=0.95)
parser.add_argument("--clip-grad", metavar="NORM", type=float, default=0.0)
parser.add_argument(
"--dataset-root",
type=str,
help="specify dataset root folder",
)
parser.add_argument(
"--dataset-folder-in-archive",
type=str,
help="specify dataset folder in archive",
)
parser.add_argument(
"--dataset-train",
default=["train-clean-100"],
nargs="+",
type=str,
help="select which part of librispeech to train with",
)
parser.add_argument(
"--dataset-valid",
default=["dev-clean"],
nargs="+",
type=str,
help="select which part of librispeech to validate with",
)
parser.add_argument(
"--distributed", action="store_true", help="enable DistributedDataParallel"
)
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument(
"--world-size", type=int, default=8, help="the world size to initiate DPP"
)
parser.add_argument("--jit", action="store_true", help="if used, model is jitted")
args = parser.parse_args()
logging.info(args)
return args
def setup_distributed(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
# initialize the process group
torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
def model_length_function(tensor):
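    # Estimate the model output length for a given input so the CTC loss receives
    # correct input lengths: 160 matches the default --hop-length (and the stride of
    # the Wav2Letter waveform frontend), and // 2 + 1 accounts for the stride-2 first
    # convolution of the acoustic model.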
if tensor.shape[1] == 1:
# waveform mode
return int(tensor.shape[0]) // 160 // 2 + 1
return int(tensor.shape[0]) // 2 + 1
def compute_error_rates(outputs, targets, decoder, language_model, metric):
output = outputs.transpose(0, 1).to("cpu")
output = decoder(output)
# Compute CER
output = language_model.decode(output.tolist())
target = language_model.decode(targets.tolist())
print_length = 20
for i in range(2):
# Print a few examples
output_print = output[i].ljust(print_length)[:print_length]
target_print = target[i].ljust(print_length)[:print_length]
logging.info("Target: %s Output: %s", target_print, output_print)
cers = [levenshtein_distance(t, o) for t, o in zip(target, output)]
cers = sum(cers)
n = sum(len(t) for t in target)
metric["cer over target length"] = cers / n
metric["cumulative cer"] += cers
metric["total chars"] += n
metric["cumulative cer over target length"] = metric["cer"] / metric["total chars"]
# Compute WER
output = [o.split(language_model.char_space) for o in output]
target = [t.split(language_model.char_space) for t in target]
wers = [levenshtein_distance(t, o) for t, o in zip(target, output)]
wers = sum(wers)
n = sum(len(t) for t in target)
metric["wer over target length"] = wers / n
metric["cumulative wer"] += wers
metric["total words"] += n
metric["cumulative wer over target length"] = metric["wer"] / metric["total words"]
def train_one_epoch(
model,
criterion,
optimizer,
scheduler,
data_loader,
decoder,
language_model,
device,
epoch,
clip_grad,
disable_logger=False,
reduce_lr_on_plateau=False,
):
model.train()
metric = MetricLogger("train", disable=disable_logger)
metric["epoch"] = epoch
for inputs, targets, tensors_lengths, target_lengths in bg_iterator(
data_loader, maxsize=2
):
start = time()
inputs = inputs.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
# keep batch first for data parallel
outputs = model(inputs).transpose(-1, -2).transpose(0, 1)
# CTC
# outputs: input length, batch size, number of classes (including blank)
# targets: batch size, max target length
# input_lengths: batch size
# target_lengths: batch size
loss = criterion(outputs, targets, tensors_lengths, target_lengths)
optimizer.zero_grad()
loss.backward()
if clip_grad > 0:
metric["gradient"] = torch.nn.utils.clip_grad_norm_(
model.parameters(), clip_grad
)
optimizer.step()
compute_error_rates(outputs, targets, decoder, language_model, metric)
try:
metric["lr"] = scheduler.get_last_lr()[0]
except AttributeError:
metric["lr"] = optimizer.param_groups[0]["lr"]
metric["batch size"] = len(inputs)
metric["n_channel"] = inputs.shape[1]
metric["n_time"] = inputs.shape[-1]
metric["dataset length"] += metric["batch size"]
metric["iteration"] += 1
metric["loss"] = loss.item()
metric["cumulative loss"] += metric["loss"]
metric["average loss"] = metric["cumulative loss"] / metric["iteration"]
metric["iteration time"] = time() - start
metric["epoch time"] += metric["iteration time"]
metric()
if reduce_lr_on_plateau and isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(metric["average loss"])
elif not isinstance(scheduler, ReduceLROnPlateau):
scheduler.step()
def evaluate(
model,
criterion,
data_loader,
decoder,
language_model,
device,
epoch,
disable_logger=False,
):
with torch.no_grad():
model.eval()
start = time()
metric = MetricLogger("validation", disable=disable_logger)
metric["epoch"] = epoch
for inputs, targets, tensors_lengths, target_lengths in bg_iterator(
data_loader, maxsize=2
):
inputs = inputs.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
# keep batch first for data parallel
outputs = model(inputs).transpose(-1, -2).transpose(0, 1)
# CTC
# outputs: input length, batch size, number of classes (including blank)
# targets: batch size, max target length
# input_lengths: batch size
# target_lengths: batch size
metric["cumulative loss"] += criterion(
outputs, targets, tensors_lengths, target_lengths
).item()
metric["dataset length"] += len(inputs)
metric["iteration"] += 1
compute_error_rates(outputs, targets, decoder, language_model, metric)
metric["average loss"] = metric["cumulative loss"] / metric["iteration"]
metric["validation time"] = time() - start
metric()
return metric["average loss"]
def main(rank, args):
# Distributed setup
if args.distributed:
setup_distributed(rank, args.world_size)
not_main_rank = args.distributed and rank != 0
logging.info("Start time: %s", datetime.now())
# Explicitly set seed to make sure models created in separate processes
# start from same random weights and biases
torch.manual_seed(args.seed)
# Empty CUDA cache
torch.cuda.empty_cache()
# Change backend for flac files
torchaudio.set_audio_backend("soundfile")
# Transforms
melkwargs = {
"n_fft": args.win_length,
"n_mels": args.n_bins,
"hop_length": args.hop_length,
}
sample_rate_original = 16000
if args.type == "mfcc":
transforms = torch.nn.Sequential(
torchaudio.transforms.MFCC(
sample_rate=sample_rate_original,
n_mfcc=args.n_bins,
melkwargs=melkwargs,
),
)
num_features = args.n_bins
elif args.type == "waveform":
transforms = torch.nn.Sequential(UnsqueezeFirst())
num_features = 1
else:
raise ValueError("Model type not supported")
if args.normalize:
transforms = torch.nn.Sequential(transforms, Normalize())
augmentations = torch.nn.Sequential()
if args.freq_mask:
augmentations = torch.nn.Sequential(
augmentations,
torchaudio.transforms.FrequencyMasking(freq_mask_param=args.freq_mask),
)
if args.time_mask:
augmentations = torch.nn.Sequential(
augmentations,
torchaudio.transforms.TimeMasking(time_mask_param=args.time_mask),
)
# Text preprocessing
char_blank = "*"
char_space = " "
char_apostrophe = "'"
labels = char_blank + char_space + char_apostrophe + string.ascii_lowercase
language_model = LanguageModel(labels, char_blank, char_space)
# Dataset
training, validation = split_process_librispeech(
[args.dataset_train, args.dataset_valid],
[transforms, transforms],
language_model,
root=args.dataset_root,
folder_in_archive=args.dataset_folder_in_archive,
)
# Decoder
if args.decoder == "greedy":
decoder = GreedyDecoder()
else:
raise ValueError("Selected decoder not supported")
# Model
model = Wav2Letter(
num_classes=language_model.length,
input_type=args.type,
num_features=num_features,
)
if args.jit:
model = torch.jit.script(model)
if args.distributed:
n = torch.cuda.device_count() // args.world_size
devices = list(range(rank * n, (rank + 1) * n))
model = model.to(devices[0])
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=devices)
else:
devices = ["cuda" if torch.cuda.is_available() else "cpu"]
model = model.to(devices[0], non_blocking=True)
model = torch.nn.DataParallel(model)
n = count_parameters(model)
logging.info("Number of parameters: %s", n)
# Optimizer
if args.optimizer == "adadelta":
optimizer = Adadelta(
model.parameters(),
lr=args.learning_rate,
weight_decay=args.weight_decay,
eps=args.eps,
rho=args.rho,
)
elif args.optimizer == "sgd":
optimizer = SGD(
model.parameters(),
lr=args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optimizer == "adam":
optimizer = Adam(
model.parameters(),
lr=args.learning_rate,
weight_decay=args.weight_decay,
)
elif args.optimizer == "adamw":
optimizer = AdamW(
model.parameters(),
lr=args.learning_rate,
weight_decay=args.weight_decay,
)
else:
raise ValueError("Selected optimizer not supported")
if args.scheduler == "exponential":
scheduler = ExponentialLR(optimizer, gamma=args.gamma)
elif args.scheduler == "reduceonplateau":
scheduler = ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3)
else:
raise ValueError("Selected scheduler not supported")
criterion = torch.nn.CTCLoss(
blank=language_model.mapping[char_blank], zero_infinity=False
)
# Data Loader
collate_fn_train = collate_factory(model_length_function, augmentations)
collate_fn_valid = collate_factory(model_length_function)
loader_training_params = {
"num_workers": args.workers,
"pin_memory": True,
"shuffle": True,
"drop_last": True,
}
loader_validation_params = loader_training_params.copy()
loader_validation_params["shuffle"] = False
loader_training = DataLoader(
training,
batch_size=args.batch_size,
collate_fn=collate_fn_train,
**loader_training_params,
)
loader_validation = DataLoader(
validation,
batch_size=args.batch_size,
collate_fn=collate_fn_valid,
**loader_validation_params,
)
# Setup checkpoint
best_loss = 1.0
load_checkpoint = args.checkpoint and os.path.isfile(args.checkpoint)
if args.distributed:
torch.distributed.barrier()
if load_checkpoint:
logging.info("Checkpoint: loading %s", args.checkpoint)
checkpoint = torch.load(args.checkpoint)
args.start_epoch = checkpoint["epoch"]
best_loss = checkpoint["best_loss"]
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
logging.info(
"Checkpoint: loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"]
)
else:
logging.info("Checkpoint: not found")
save_checkpoint(
{
"epoch": args.start_epoch,
"state_dict": model.state_dict(),
"best_loss": best_loss,
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
},
False,
args.checkpoint,
not_main_rank,
)
if args.distributed:
torch.distributed.barrier()
torch.autograd.set_detect_anomaly(False)
for epoch in range(args.start_epoch, args.epochs):
logging.info("Epoch: %s", epoch)
train_one_epoch(
model,
criterion,
optimizer,
scheduler,
loader_training,
decoder,
language_model,
devices[0],
epoch,
args.clip_grad,
not_main_rank,
not args.reduce_lr_valid,
)
if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:
loss = evaluate(
model,
criterion,
loader_validation,
decoder,
language_model,
devices[0],
epoch,
not_main_rank,
)
is_best = loss < best_loss
best_loss = min(loss, best_loss)
save_checkpoint(
{
"epoch": epoch + 1,
"state_dict": model.state_dict(),
"best_loss": best_loss,
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
},
is_best,
args.checkpoint,
not_main_rank,
)
if args.reduce_lr_valid and isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(loss)
logging.info("End time: %s", datetime.now())
if args.distributed:
torch.distributed.destroy_process_group()
def spawn_main(main, args):
if args.distributed:
torch.multiprocessing.spawn(
main, args=(args,), nprocs=args.world_size, join=True
)
else:
main(0, args)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
args = parse_args()
spawn_main(main, args)
from typing import List, Union
def levenshtein_distance(r: Union[str, List[str]], h: Union[str, List[str]]):
"""
Calculate the Levenshtein distance between two lists or strings.
"""
# Initialisation
dold = list(range(len(h) + 1))
dnew = list(0 for _ in range(len(h) + 1))
# Computation
for i in range(1, len(r) + 1):
dnew[0] = i
for j in range(1, len(h) + 1):
if r[i - 1] == h[j - 1]:
dnew[j] = dold[j - 1]
else:
substitution = dold[j - 1] + 1
insertion = dnew[j - 1] + 1
deletion = dold[j] + 1
dnew[j] = min(substitution, insertion, deletion)
dnew, dold = dold, dnew
return dold[-1]
if __name__ == "__main__":
assert levenshtein_distance("abc", "abc") == 0
assert levenshtein_distance("aaa", "aba") == 1
assert levenshtein_distance("aba", "aaa") == 1
assert levenshtein_distance("aa", "aaa") == 1
assert levenshtein_distance("aaa", "aa") == 1
assert levenshtein_distance("abc", "bcd") == 2
assert levenshtein_distance(["hello", "world"], ["hello", "world", "!"]) == 1
assert levenshtein_distance(["hello", "world"], ["world", "hello", "!"]) == 2
import torch
class Normalize(torch.nn.Module):
def forward(self, tensor):
return (tensor - tensor.mean(-1, keepdim=True)) / tensor.std(-1, keepdim=True)
class UnsqueezeFirst(torch.nn.Module):
def forward(self, tensor):
return tensor.unsqueeze(0)
import json
import logging
import os
import shutil
from collections import defaultdict
import torch
class MetricLogger(defaultdict):
def __init__(self, name, print_freq=1, disable=False):
super().__init__(lambda: 0.0)
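        # Missing keys default to 0.0, so counters such as self["iteration"] += 1
        # work without explicit initialization.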
self.disable = disable
self.print_freq = print_freq
self._iter = 0
self["name"] = name
def __str__(self):
return json.dumps(self)
def __call__(self):
self._iter = (self._iter + 1) % self.print_freq
if not self.disable and not self._iter:
print(self, flush=True)
def save_checkpoint(state, is_best, filename, disable):
"""
Save the model to a temporary file first,
then copy it to filename, in case the signal interrupts
the torch.save() process.
"""
if disable:
return
if filename == "":
return
tempfile = filename + ".temp"
    # Remove tempfile in case of an interruption during the copy from tempfile to filename
if os.path.isfile(tempfile):
os.remove(tempfile)
torch.save(state, tempfile)
if os.path.isfile(tempfile):
os.rename(tempfile, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
logging.warning("Checkpoint: saved")
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)