Commit 928248d7 authored by Zhaoheng Ni's avatar Zhaoheng Ni
Browse files

Improve hubert recipe for pre-training and fine-tuning (#2744)

Summary:
following pr https://github.com/pytorch/audio/issues/2716
- For preprocessing
  - The HuBERT feature takes lots of memory which may not fit some machines. Enable to use a subset of feature for training a k-means model.

- For pre-training
  - Normalize the loss based on the total number of masked frames across all GPUs.
  - Use mixed precision training. fp16 is not well supported in pytorch_lightning.
  - Log accuracies of masked/unmasked frames during training.
  - Clip the gradients with norm `10.0`.

- For ASR fine-tuning
  - Normalize the loss based on the total number of batches across all GPUs, same as in the conformer recipe of TorchAudio.
  - Use mixed precision training.
  - Add "|" after the end of transcription to capture the silence/word termination, same as in fairseq recipe.

- Update the WER results on LibriSpeech dev and test sets.

|                   | WER% (Viterbi)|  WER% (KenLM) |
|:-----------------:|--------------:|--------------:|
| dev-clean         |       10.9    |       4.2     |
| dev-other         |       17.5    |       9.4     |
| test-clean        |       10.9    |       4.4     |
| test-other        |       17.8    |       9.5     |

Pull Request resolved: https://github.com/pytorch/audio/pull/2744

Reviewed By: carolineechen

Differential Revision: D40282322

Pulled By: nateanl

fbshipit-source-id: 4723584c912e70e8970149fe09de005385eaab90
parent 97baba1b
......@@ -29,7 +29,7 @@ After the first iteration of pre-training, the intermediate transformer layer's
Sample SLURM command for the second iteration of preprocessing. The 6th transformer layer's output is used as the input feature for training KMeans model. Note that the number of clusters is increased to 500 to improve the performance.
```
srun --cpus-per-task=24 python preprocess.py --root-dir /home/datasets --feat-type hubert --exp-dir ./exp --layer-index 6 --checkpoint-path ./exp_iter1/checkpoints_librispeech_hubert_pretrain_base/xxx.ckpt --num-cluster 500
srun --cpus-per-task=24 python preprocess.py --root-dir /home/datasets --feat-type hubert --exp-dir ./exp --layer-index 6 --checkpoint-path ./exp_iter1/ --num-rank 40 checkpoints_librispeech_hubert_pretrain_base/xxx.ckpt --num-cluster 500 --percent 0.1
```
### Pre-training (2nd iteration)
......@@ -67,9 +67,9 @@ srun python evaluate.py --librispeech_path /root/datasets/ --checkpoint ./exp_fi
### CTC Decoding with language model
torchaudio provides a CTCDecoder feature that is based on [Flashlight](https://github.com/flashlight/flashlight). The decoder supports KenLM language model. Use `--use-lm` to enable CTC decoding with KenLM 4-gram language model.
Sample SLURM command for evaluation with KenLM language model:
Sample SLURM command for evaluation with KenLM language model (use the checkpoint that has the lowest validation loss):
```
srun python evaluate.py --librispeech_path /root/datasets/ --checkpoint ./exp_finetune/checkpoints_hubert_pretrain_base/epoch\=109-step\=19999.ckpt --split test-clean --use-lm --beam-size 1500 --lm-weight 2.46 --word-score -0.59
srun python evaluate.py --librispeech_path /root/datasets/ --checkpoint ./exp_finetune/checkpoints_hubert_pretrain_base/epoch\=106-step\=19500.ckpt --split test-clean --use-lm --beam-size 1500 --lm-weight 2.46 --word-score -0.59
```
### WER results
......@@ -77,7 +77,7 @@ The table below contains WER results for fine-tuning HuBERT Base model on `10h`
| | WER% (Viterbi)| WER% (KenLM) |
|:-----------------:|--------------:|--------------:|
| dev-clean | 10.7 | 4.4 |
| dev-other | 18.3 | 9.7 |
| test-clean | 10.8 | 4.4 |
| test-other | 18.5 | 10.1 |
| dev-clean | 10.9 | 4.2 |
| dev-other | 17.5 | 9.4 |
| test-clean | 10.9 | 4.4 |
| test-other | 17.8 | 9.5 |
......@@ -463,6 +463,8 @@ class CollateFnLibriLightLimited:
label2id = _get_label2id()
for sample in batch:
waveform, transcript = sample[0], sample[2]
# add one "|" symbol after the end of transcription as the word termination
transcript = transcript + "|"
label = torch.tensor([label2id[e] for e in transcript.replace(" ", "|").upper()])
audio_length = waveform.size(1)
label_length = label.size(0)
......
import argparse
import logging
from typing import Dict, List, Optional
from typing import Dict, List
import torch
import torch.nn.functional as F
......@@ -36,10 +36,9 @@ def _viterbi_decode(emission: torch.Tensor, id2token: Dict, blank_idx: int = 0)
Returns:
(List of str): The decoding result. List of string in lower case.
"""
hypothesis = F.log_softmax(emission, dim=-1)
hypothesis = hypothesis.argmax(-1).unique_consecutive()
hypothesis = emission.argmax(-1).unique_consecutive()
hypothesis = hypothesis[hypothesis != blank_idx]
hypothesis = "".join(id2token[int(i)] for i in hypothesis).replace("|", " ")
hypothesis = "".join(id2token[int(i)] for i in hypothesis).replace("|", " ").strip()
return hypothesis.split()
......@@ -47,7 +46,7 @@ def _ctc_decode(emission, decoder: CTCDecoder) -> List[str]:
"""Run CTC decoding with a KenLM language model.
Args:
emission (torch.Tensor): Output of CTC layer. Tensor with dimensions (..., time, num_tokens).
emission (torch.Tensor): Output of CTC layer. Tensor with dimensions `(..., time, num_tokens)`.
decoder (CTCDecoder): The initialized CTCDecoder.
Returns:
......@@ -55,13 +54,19 @@ def _ctc_decode(emission, decoder: CTCDecoder) -> List[str]:
"""
hypothesis = decoder(emission)
hypothesis = hypothesis[0][0].words
hypothesis = [word for word in hypothesis if word != " "]
return hypothesis
def run_inference(args):
if args.use_gpu:
device = torch.device("cuda")
else:
device = torch.device("cpu")
# Load the fine-tuned HuBERTPretrainModel from checkpoint.
model = _load_checkpoint(args.checkpoint)
model.eval()
model.eval().to(device)
if args.use_lm:
# get decoder files
......@@ -92,13 +97,14 @@ def run_inference(args):
transcript = transcript.strip().lower().strip().replace("\n", "")
with torch.inference_mode():
emission, _ = model(waveform)
emission, _ = model(waveform.to(device))
emission = F.log_softmax(emission, dim=-1)
if args.use_lm:
hypothesis = _ctc_decode(emission, decoder)
hypothesis = _ctc_decode(emission.cpu(), decoder)
else:
hypothesis = _viterbi_decode(emission, id2token)
total_edit_distance += torchaudio.functional.edit_distance(transcript.split(), hypothesis)
total_edit_distance += torchaudio.functional.edit_distance(hypothesis, transcript.split())
total_length += len(transcript.split())
if idx % 100 == 0:
......@@ -138,9 +144,9 @@ def _parse_args():
)
parser.add_argument(
"--beam-size-token",
type=Optional[int],
default=None,
help="Number of tokens to consider at each beam search step. (Default: None)",
type=int,
default=29,
help="Number of tokens to consider at each beam search step. (Default: 29)",
)
parser.add_argument(
"--beam-threshold", type=int, default=100, help="Beam threshold for pruning hypotheses. (Default: 100)"
......@@ -161,6 +167,7 @@ def _parse_args():
"--unk-score", type=float, default=float("-inf"), help="Unknown word insertion score. (Default: -inf)"
)
parser.add_argument("--sil-score", type=float, default=0, help="Silence insertion score. (Default: 0)")
parser.add_argument("--use-gpu", action="store_true", help="Whether to use GPU for decoding.")
parser.add_argument("--debug", action="store_true", help="Whether to use debug level for logging.")
return parser.parse_args()
......
......@@ -16,6 +16,7 @@ from lightning import HuBERTFineTuneModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.seed import seed_everything
logger = logging.getLogger(__name__)
......@@ -29,10 +30,11 @@ class _Formatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
def run_train(args):
seed_everything(1337)
checkpoint_dir = args.exp_dir / f"checkpoints_{args.model_name}"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
monitor="val_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
......@@ -40,7 +42,7 @@ def run_train(args):
)
train_checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/train_loss",
monitor="train_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
......@@ -60,7 +62,8 @@ def run_train(args):
replace_sampler_ddp=False,
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
accumulate_grad_batches=args.accumulate_grad_batches,
val_check_interval=500,
check_val_every_n_epoch=None,
)
model = HuBERTFineTuneModule(
......@@ -73,6 +76,7 @@ def run_train(args):
mask_prob=args.mask_prob,
mask_channel_prob=args.mask_channel_prob,
mask_channel_length=args.mask_channel_length,
num_classes=args.num_classes,
aux_num_out=args.aux_num_out,
checkpoint=args.checkpoint,
dataset_path=args.dataset_path,
......@@ -87,7 +91,7 @@ def run_train(args):
hold_updates=args.hold_updates,
decay_updates=args.decay_updates,
)
trainer.fit(model)
trainer.fit(model, ckpt_path=args.resume_checkpoint)
def _parse_args():
......@@ -101,6 +105,18 @@ def _parse_args():
required=True,
help="Path to the LibriSpeech and LibriLightLimited datasets.",
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the pre-trained HuBERTPretrainModel checkpoint as the initialization.",
)
parser.add_argument(
"--resume-checkpoint",
default=None,
type=str,
help="The path to the checkpoint to resume the fine-tuning if training fails in the middle.",
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp_finetune"),
......@@ -141,7 +157,7 @@ def _parse_args():
)
parser.add_argument(
"--encoder-layer-drop",
default=0.1,
default=0.05,
type=float,
help="Probability to drop each encoder layer during training. (Default: 0.1)",
)
......@@ -164,10 +180,11 @@ def _parse_args():
help="Minimum space between spans (if no overlap is enabled) for channel masking." "(Default: 64)",
)
parser.add_argument(
"--accumulate-grad-batches",
default=1,
"--num-classes",
choices=[100, 500],
type=int,
help="Number of batches to accumulate the gradients during training. (Default: 1)",
default=500,
help="The ``num_class`` in the pre-trained checkpoint. (Default: 500)",
)
parser.add_argument(
"--aux-num-out",
......@@ -176,13 +193,7 @@ def _parse_args():
help="The dimension of linear layer for CTC training. (Default: 29)",
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the pre-trained HuBERTPretrainModel checpoint.",
)
parser.add_argument(
"--learning-rate", default=1e-4, type=float, help="The learning rate of Adam optimizer. (Default: 2e-5)"
"--learning-rate", default=5e-5, type=float, help="The learning rate of Adam optimizer. (Default: 5e-5)"
)
parser.add_argument(
"--betas",
......@@ -198,7 +209,7 @@ def _parse_args():
)
parser.add_argument(
"--weight-decay",
default=1e-6,
default=0.0,
type=float,
help="Weight decay (L2 penalty) (Default: 0.0)",
)
......
import math
from typing import Tuple
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
......@@ -94,6 +94,29 @@ class TriStageLRScheduler(torch.optim.lr_scheduler._LRScheduler):
return [base_lr * self.final_lr_scale for base_lr in self.base_lrs]
def _compute_accuracy(logits: torch.Tensor):
with torch.no_grad():
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
return corr, count
def _reset_stats():
return {
"train": {
"correct": 0.0,
"count": 0.0,
},
"val": {
"correct": 0.0,
"count": 0.0,
},
}
class HuBERTPreTrainModule(LightningModule):
def __init__(
self,
......@@ -109,6 +132,7 @@ class HuBERTPreTrainModule(LightningModule):
betas: Tuple[float, float],
eps: float,
weight_decay: float,
clip_norm: Optional[float],
warmup_updates: int,
max_updates: int,
):
......@@ -124,29 +148,71 @@ class HuBERTPreTrainModule(LightningModule):
self.model = torchaudio.models.hubert_pretrain_xlarge()
else:
raise ValueError(f"Unsupported model name: {model_name}")
self.automatic_optimization = False
self.scaler = torch.cuda.amp.GradScaler()
self.loss = hubert_loss
self.optimizer = torch.optim.AdamW(
self.model.parameters(), lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay
)
self.clip_norm = clip_norm
self.lr_scheduler = LinearDecayLRScheduler(self.optimizer, warmup_updates, max_updates)
self.dataset = dataset
self.dataset_path = dataset_path
self.feature_type = feature_type
self.seconds_per_batch = seconds_per_batch
self.mask_stats = _reset_stats()
self.unmask_stats = _reset_stats()
self.nan_loss_count = 0.0
def _step(self, batch: Batch, batch_idx, step_type):
if batch is None:
return None
return None, None
waveforms, labels, audio_lengths = batch
logit_m, logit_u, feature_penalty = self.model(
waveforms,
labels,
audio_lengths,
)
if step_type == "val":
with torch.no_grad():
logit_m, logit_u, feature_penalty = self.model(
waveforms,
labels,
audio_lengths,
)
else:
logit_m, logit_u, feature_penalty = self.model(
waveforms,
labels,
audio_lengths,
)
loss = self.loss(logit_m, logit_u, feature_penalty)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
if not torch.isinf(loss) and not torch.isnan(loss):
self.log(f"{step_type}_loss", loss.item() / logit_m.size(0), on_step=True, on_epoch=True)
else:
self.nan_loss_count += 1
self.log("nan_loss_count", self.nan_loss_count, on_step=True, on_epoch=True)
# log accuracies of masked and unmasked frames
correct_m, count_m = _compute_accuracy(logit_m)
correct_u, count_u = _compute_accuracy(logit_u)
self.mask_stats[step_type]["correct"] += correct_m
self.mask_stats[step_type]["count"] += count_m
self.unmask_stats[step_type]["correct"] += correct_u
self.unmask_stats[step_type]["count"] += count_u
self.log(
f"{step_type}_masked_accuracy",
self.mask_stats[step_type]["correct"] / self.mask_stats[step_type]["count"],
on_step=True,
on_epoch=True,
sync_dist=True,
prog_bar=step_type == "train",
)
self.log(
f"{step_type}_unmasked_accuracy",
self.unmask_stats[step_type]["correct"] / self.unmask_stats[step_type]["count"],
on_step=True,
on_epoch=True,
sync_dist=True,
prog_bar=step_type == "train",
)
return loss, logit_m.size(0)
def configure_optimizers(self):
return (
......@@ -160,20 +226,68 @@ class HuBERTPreTrainModule(LightningModule):
)
def training_step(self, batch: Batch, batch_idx):
return self._step(batch, batch_idx, "train")
"""Custom training step with loss normalization and automatic mixed precision training.
By default, DDP does the following on each train step:
- For each GPU, compute loss and gradient on shard of training data.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / N, where N is the world
size (total number of GPUs).
- Update parameters on each GPU.
Here, we do the following:
- For k-th GPU, compute loss and scale it by (N / num_frames), where num_frames is
the sum of masked frames across all GPUs. Compute gradient from scaled loss.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / num_frames.
- Update parameters on each GPU.
Doing so allows us to account for the variability in number of masked frames in
variable-length sequential data.
"""
opt = self.optimizers()
opt.zero_grad()
with torch.cuda.amp.autocast(enabled=False):
loss, num_frame = self._step(batch, batch_idx, "train")
if torch.isinf(loss) or torch.isnan(loss):
opt.zero_grad()
return None
# normalize the loss based on the sum of num_frame across all GPUs
num_frames = self.all_gather(num_frame)
self.log("Gathered number of frames", num_frames.float().sum(), on_step=True, on_epoch=True)
loss *= num_frames.size(0) / num_frames.sum() # world size / num_frames
# backward the loss and clip the gradients
loss = self.scaler.scale(loss)
self.manual_backward(loss)
self.scaler.unscale_(opt)
if self.clip_norm is not None:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_norm)
# optimization
self.scaler.step(opt)
sch = self.lr_schedulers()
sch.step()
self.scaler.update()
return loss
def validation_step(self, batch: Batch, batch_idx):
return self._step(batch, batch_idx, "val")
return self._step(batch, batch_idx, "val")[0]
def on_validation_end(self):
    # Reset the running masked/unmasked accuracy counters once validation
    # finishes, so the logged accuracies do not accumulate across epochs.
    self.mask_stats = _reset_stats()
    self.unmask_stats = _reset_stats()
def train_dataloader(self):
dataset = HuBERTDataSet(self.dataset_path, self.dataset, "train")
sampler = BucketizeBatchSampler(
dataset.len_list,
num_buckets=10000,
num_buckets=1000,
max_token_count=self.seconds_per_batch * 16000,
min_len=32000,
max_len=250000,
shuffle=True,
shuffle=False,
)
sampler = DistributedBatchSampler(sampler, shuffle=True)
sampler.set_epoch(self.current_epoch)
......@@ -217,6 +331,7 @@ class HuBERTFineTuneModule(LightningModule):
mask_prob: float,
mask_channel_prob: float,
mask_channel_length: float,
num_classes: int,
aux_num_out: int,
checkpoint: str,
dataset_path: str,
......@@ -243,6 +358,7 @@ class HuBERTFineTuneModule(LightningModule):
mask_prob=mask_prob,
mask_channel_prob=mask_channel_prob,
mask_channel_length=mask_channel_length,
num_classes=num_classes,
)
elif model_name == "hubert_large":
self.model = torchaudio.models.hubert_pretrain_large(
......@@ -254,6 +370,7 @@ class HuBERTFineTuneModule(LightningModule):
mask_prob=mask_prob,
mask_channel_prob=mask_channel_prob,
mask_channel_length=mask_channel_length,
num_classes=num_classes,
)
elif model_name == "hubert_xlarge":
self.model = torchaudio.models.hubert_pretrain_xlarge(
......@@ -265,6 +382,7 @@ class HuBERTFineTuneModule(LightningModule):
mask_prob=mask_prob,
mask_channel_prob=mask_channel_prob,
mask_channel_length=mask_channel_length,
num_classes=num_classes,
)
else:
raise ValueError(f"Unsupported model name: {model_name}.")
......@@ -274,7 +392,7 @@ class HuBERTFineTuneModule(LightningModule):
p.requires_grad = False
self.loss_fn = torch.nn.CTCLoss(blank=0, reduction="sum", zero_infinity=True)
self.optimizer = torch.optim.AdamW(
list(self.aux.parameters()) + list(self.model.wav2vec2.encoder.parameters()),
list(self.aux.parameters()) + list(self.model.parameters()),
lr=learning_rate,
betas=betas,
eps=adam_eps,
......@@ -285,16 +403,18 @@ class HuBERTFineTuneModule(LightningModule):
self.dataset_path = dataset_path
self.seconds_per_batch = seconds_per_batch
self.subset = subset
self.automatic_optimization = False
self.scaler = torch.cuda.amp.GradScaler()
def _load_checkpoint(self, checkpoint):
# load pretrain model
# load pretrain model from checkpoint
state_dict = torch.load(checkpoint, map_location=torch.device("cpu"))
state_dict = state_dict["state_dict"]
s = {}
for k in state_dict:
if "wav2vec2" in k:
s[k.replace("model.wav2vec2.", "")] = state_dict[k]
self.model.wav2vec2.load_state_dict(s)
if "model." in k:
s[k.replace("model.", "")] = state_dict[k]
self.model.load_state_dict(s)
def _step(self, batch: Batch_FineTune, batch_idx, step_type):
if batch is None:
......@@ -315,8 +435,6 @@ class HuBERTFineTuneModule(LightningModule):
x, _ = self.model.mask_generator(x, padding_mask)
x = self.model.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
logits = self.aux(x)
logits[padding_mask][..., 0] = 0
logits[padding_mask][..., 1:] = float("-inf")
log_probs = F.log_softmax(logits, dim=-1)
log_probs = log_probs.transpose(0, 1)
loss = self.loss_fn(
......@@ -325,7 +443,7 @@ class HuBERTFineTuneModule(LightningModule):
out_len,
label_lengths,
)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
self.log(f"{step_type}_loss", loss.item() / waveforms.size(0), on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
......@@ -339,7 +457,47 @@ class HuBERTFineTuneModule(LightningModule):
)
def training_step(self, batch: Batch_FineTune, batch_idx):
return self._step(batch, batch_idx, "train")
"""Custom training step with loss normalization and automatic mixed precision training.
By default, DDP does the following on each train step:
- For each GPU, compute loss and gradient on shard of training data.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / N, where N is the world
size (total number of GPUs).
- Update parameters on each GPU.
Here, we do the following:
- For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / B_total.
- Update parameters on each GPU.
Doing so allows us to account for the variability in batch sizes that
variable-length sequential data commonly yields.
"""
opt = self.optimizers()
opt.zero_grad()
with torch.cuda.amp.autocast(enabled=False):
loss = self._step(batch, batch_idx, "train")
# normalize the loss based on the sum of batch_size across all GPUs
batch_size = batch[0].size(0)
batch_sizes = self.all_gather(batch_size)
self.log("Gathered batch size", batch_sizes.sum(), on_step=True, on_epoch=True)
loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size
# backward the loss and clip the gradients
loss = self.scaler.scale(loss)
self.manual_backward(loss)
self.scaler.unscale_(opt)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
# optimization
self.scaler.step(opt)
sch = self.lr_schedulers()
sch.step()
self.scaler.update()
def validation_step(self, batch: Batch_FineTune, batch_idx):
    """Run a single validation step and return the logged validation loss."""
    return self._step(batch, batch_idx, "val")
......
......@@ -60,6 +60,12 @@ def _parse_args():
type=int,
help="The number of clusters for KMeans clustering.",
)
parser.add_argument(
"--percent",
default=-1,
type=float,
help="The percent of data for KMeans clustering. If negative, use all data. (Default: -1)",
)
args = parser.parse_args()
return args
......@@ -114,6 +120,7 @@ def main(args):
args.num_rank,
km_dir,
args.num_cluster,
args.percent,
)
# Predict labels for MFCC or HuBERT features
......
......@@ -7,11 +7,12 @@ python train.py --dataset-path ./exp/data/mfcc/ --feature-type mfcc --num-classe
import logging
import pathlib
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter
from typing import Optional, Tuple
from typing import Tuple
from lightning import HuBERTPreTrainModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.seed import seed_everything
logger = logging.getLogger(__name__)
......@@ -25,10 +26,11 @@ class _Formatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
def run_train(args):
seed_everything(1337)
checkpoint_dir = args.exp_dir / f"checkpoints_{args.dataset}_{args.model_name}"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
monitor="val_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
......@@ -36,7 +38,7 @@ def run_train(args):
)
train_checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/train_loss",
monitor="train_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
......@@ -54,7 +56,6 @@ def run_train(args):
accelerator="gpu",
strategy="ddp",
replace_sampler_ddp=False,
gradient_clip_val=args.clip_norm,
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
)
......@@ -71,6 +72,7 @@ def run_train(args):
betas=args.betas,
eps=args.eps,
weight_decay=args.weight_decay,
clip_norm=args.clip_norm,
warmup_updates=args.warmup_updates,
max_updates=args.max_updates,
)
......@@ -90,7 +92,7 @@ def _parse_args():
)
parser.add_argument(
"--resume-checkpoint",
type=Optional[pathlib.Path],
type=pathlib.Path,
default=None,
help="Path to the feature and label directories. (Default: None)",
)
......@@ -159,9 +161,9 @@ def _parse_args():
)
parser.add_argument(
"--clip-norm",
default=None,
type=Optional[float],
help="The gradient norm value to clip. (Default: None)",
default=10.0,
type=float,
help="The gradient norm value to clip. (Default: 10.0)",
)
parser.add_argument(
"--num-nodes",
......
......@@ -60,12 +60,6 @@ def extract_feature_mfcc(
).to(device)
waveform = waveform[0].to(device)
mfccs = feature_extractor(waveform) # (freq, time)
# mfccs = torchaudio.compliance.kaldi.mfcc(
# waveform=waveform,
# sample_frequency=sample_rate,
# use_energy=False,
# ) # (time, freq)
# mfccs = mfccs.transpose(0, 1) # (freq, time)
deltas = torchaudio.functional.compute_deltas(mfccs)
ddeltas = torchaudio.functional.compute_deltas(deltas)
concat = torch.cat([mfccs, deltas, ddeltas], dim=0)
......
......@@ -20,12 +20,14 @@ def load_feature(
feat_dir: Path,
split: str,
num_rank: int,
percent: float,
) -> Tuple[Tensor, Tensor]:
r"""Loading features from pre-saved `.pt` files.
Args:
feat_dir (Path): The directory that stores the feature files.
split (str): The split of data. Options: [``train``, ``valid``].
num_rank (int): The number of ranks for multi-processing in feature extraction.
percent (float): The percent of data for training k-means model. If negative, use all data for training.
Returns:
(Tensor, Tensor)
......@@ -37,9 +39,23 @@ def load_feature(
for rank in range(1, num_rank + 1):
feat_path, len_path = _get_feat_lens_paths(feat_dir, split, rank, num_rank)
feat = torch.load(feat_path)
length = torch.load(len_path)
feats.append(feat)
lens.append(length)
length = torch.load(len_path).int()
if percent < 0:
feats.append(feat)
lens.append(length)
else:
offsets = [0] + torch.cumsum(length, dim=0, dtype=torch.int).tolist()
nsample = int(length.shape[0] * percent)
indices = torch.randperm(length.shape[0])[0:nsample]
indices = torch.sort(indices)[0]
mask = []
for i in range(indices.shape[0]):
index = indices[i]
mask += list(range(offsets[index], offsets[index] + length[index]))
mask = torch.tensor(mask, dtype=torch.int)
feat = torch.index_select(feat, 0, mask)
feats.append(feat)
lens.append(length[indices])
feats = torch.cat(feats)
lens = torch.cat(lens)
return feats, lens
......@@ -51,6 +67,7 @@ def learn_kmeans(
num_rank: int,
km_dir: Path,
n_clusters: int,
percent: float = -1,
init: str = "k-means++",
max_iter: int = 100,
batch_size: int = 10000,
......@@ -66,6 +83,8 @@ def learn_kmeans(
num_rank (int): The number of ranks for multi-processing in feature extraction.
km_dir (Path): The directory to store the KMeans clustering model.
n_clusters (int): The number of clusters.
percent (float): The percent of data for training k-means model.
If negative, use all data for training. (Default: -1)
init (str, optional): Method for initialization. Options: [``k-means++``, ``random``].
(Default: ``k-means++``)
max_iter (int, optional): Maximum number of iterations over the complete dataset. (Default: 100)
......@@ -102,6 +121,7 @@ def learn_kmeans(
feat_dir,
split,
num_rank,
percent,
)
feats = feats.numpy()
km_model.fit(feats)
......@@ -157,19 +177,16 @@ def get_km_label(
km_path = _get_model_path(km_dir)
label_path = label_dir / f"label_{split}.pt"
apply_kmeans = ApplyKmeans(km_path, device)
feats, lens = load_feature(
feat_dir,
split,
num_rank,
)
feats = feats
lens = lens.long()
offset = 0
assert feats.shape[0] == lens.sum()
with open(label_path, "w") as f:
for i in range(lens.shape[0]):
feat = feats[offset : offset + lens[i]].to(device)
offset += lens[i]
label = apply_kmeans(feat).tolist()
f.write(" ".join(map(str, label)) + "\n")
for rank in range(1, num_rank + 1):
offset = 0
feat_path, len_path = _get_feat_lens_paths(feat_dir, split, rank, num_rank)
feats = torch.load(feat_path)
length = torch.load(len_path).int()
assert feats.shape[0] == length.sum()
labels = apply_kmeans(feats.to(device)).tolist()
for i in range(length.shape[0]):
label = labels[offset : offset + length[i]]
offset += length[i]
f.write(" ".join(map(str, label)) + "\n")
_LG.info("Finished predicting labels successfully")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment