Commit 0092aa3c authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Fix hubert fine-tuning recipe bugs (#2588)

Summary:
- The optimizer in the fine-tuning recipe should also be `AdamW`, see https://github.com/pytorch/audio/pull/2412. A minimal sketch of the corrected setup appears after this list.
- Fix the import of `DistributedBatchSampler` in the HuBERT dataset package so it is re-exported with the other dataset utilities.
- Fix the `dataset_path` argument name in the fine-tuning module.
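
A minimal sketch of the corrected optimizer setup, using toy stand-ins for the recipe's modules (the real recipe builds these inside its fine-tuning LightningModule; `betas` here is a placeholder value, not necessarily what the recipe uses):

```
import torch

# Toy stand-ins for the recipe's modules (assumptions for illustration only).
encoder = torch.nn.Linear(768, 768)  # stands in for model.wav2vec2.encoder
aux = torch.nn.Linear(768, 29)       # stands in for the auxiliary CTC head

learning_rate = 5e-5   # matches the sample command in the README
betas = (0.9, 0.98)    # placeholder; the recipe passes in its own betas

# The fix: AdamW (decoupled weight decay) instead of Adam, over the
# auxiliary head and the encoder parameters only.
optimizer = torch.optim.AdamW(
    list(aux.parameters()) + list(encoder.parameters()),
    lr=learning_rate,
    betas=betas,
)
```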

Pull Request resolved: https://github.com/pytorch/audio/pull/2588

Reviewed By: carolineechen

Differential Revision: D38243423

Pulled By: nateanl

fbshipit-source-id: badc88ce9eddfd71270201a65ae89433fae2733f
parent d84ce3b2
@@ -50,7 +50,7 @@ Sample SLURM command for fine-tuning on `10h` subset of `LibriLightLimited` data
 ```
 srun --gpus-per-node=1 -N 1 --ntasks-per-node=1 --cpus-per-task=10 \
 python finetune.py --dataset-path /root/datasets/ --exp-dir ./exp_finetune \
---checkpoint /exp_iter2/checkpoints_librispeech_hubert_pretrain_base/epoch=361-step=399999.ckpt \
+--checkpoint ./exp_iter2/checkpoints_librispeech_hubert_pretrain_base/epoch=361-step=399999.ckpt \
 --gpus 1 --debug --warmup-updates 2000 --hold-updates 8000 --decay-updates 10000 --max-updates 20000 --learning-rate 5e-5
 ```
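
The `--warmup-updates`, `--hold-updates`, and `--decay-updates` flags above describe a tri-stage learning-rate schedule. A rough sketch of the multiplier such a schedule applies, where the piecewise-linear shape and the initial/final scale factors are assumptions rather than the recipe's actual scheduler:

```
def tri_stage_lr_factor(step, warmup, hold, decay, init_scale=0.01, final_scale=0.05):
    """Learning-rate multiplier: linear warmup, constant hold, linear decay."""
    if step < warmup:
        # Ramp linearly from init_scale up to 1.0 over the warmup updates.
        return init_scale + (1.0 - init_scale) * step / warmup
    if step < warmup + hold:
        # Hold at the peak learning rate.
        return 1.0
    if step < warmup + hold + decay:
        # Decay linearly from 1.0 down to final_scale.
        progress = (step - warmup - hold) / decay
        return 1.0 - (1.0 - final_scale) * progress
    return final_scale

# With the flags from the sample command: 2000 warmup, 8000 hold, 10000 decay updates.
print(tri_stage_lr_factor(1000, 2000, 8000, 10000))   # mid-warmup, about 0.5
print(tri_stage_lr_factor(15000, 2000, 8000, 10000))  # mid-decay, about 0.5
```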
@@ -4,6 +4,7 @@ from .hubert_dataset import (
     BucketizeBatchSampler,
     CollateFnHubert,
     CollateFnLibriLightLimited,
+    DistributedBatchSampler,
     HuBERTDataSet,
 )
@@ -14,5 +15,6 @@ __all__ = [
     "BucketizeBatchSampler",
     "CollateFnHubert",
     "CollateFnLibriLightLimited",
+    "DistributedBatchSampler",
     "HuBERTDataSet",
 ]
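
With `DistributedBatchSampler` re-exported alongside the other dataset utilities, the fine-tuning module can import it from the dataset package. For context, a generic illustration of what such a sampler does, sharding a batch sampler's output across data-parallel ranks; this is a minimal stand-in, not the repository's `DistributedBatchSampler` implementation:

```
class ShardedBatchSampler:
    """Give each rank every num_replicas-th batch produced by a base batch sampler."""

    def __init__(self, batch_sampler, num_replicas, rank):
        self.batch_sampler = batch_sampler
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        for i, batch in enumerate(self.batch_sampler):
            if i % self.num_replicas == self.rank:
                yield batch

# Tiny demonstration with a plain list of index batches standing in for
# the output of BucketizeBatchSampler.
batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
print(list(ShardedBatchSampler(batches, num_replicas=2, rank=0)))  # [[0, 1], [4, 5]]
print(list(ShardedBatchSampler(batches, num_replicas=2, rank=1)))  # [[2, 3], [6, 7]]
```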
@@ -75,7 +75,7 @@ def run_train(args):
         mask_channel_length=args.mask_channel_length,
         aux_num_out=args.aux_num_out,
         checkpoint=args.checkpoint,
-        dataset_paths=args.dataset_path,
+        dataset_path=args.dataset_path,
         seconds_per_batch=args.seconds_per_batch,
         subset=args.subset,
         learning_rate=args.learning_rate,
@@ -273,7 +273,7 @@ class HuBERTFineTuneModule(LightningModule):
         for p in self.model.wav2vec2.feature_extractor.parameters():
             p.requires_grad = False
         self.loss_fn = torch.nn.CTCLoss(blank=0, reduction="sum", zero_infinity=True)
-        self.optimizer = torch.optim.Adam(
+        self.optimizer = torch.optim.AdamW(
             list(self.aux.parameters()) + list(self.model.wav2vec2.encoder.parameters()),
             lr=learning_rate,
             betas=betas,