Commit e9083571 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Update HuBERT/SSL training recipes to support Lightning 2.x (#3396)

Summary:
There are some BC-breaking changes in the move from the pytorch_lightning package to the unified lightning library. This PR adjusts the recipes for those changes so they work with the latest lightning releases.
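All of the import changes below follow one migration pattern: in Lightning 2.x the PyTorch-specific APIs live in the unified lightning distribution under lightning.pytorch, and seed_everything moved out of the utilities submodule. A side-by-side sketch of the pattern (imports only, drawn from the hunks below):

    # Lightning 1.x
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.utilities.seed import seed_everything

    # Lightning 2.x
    from lightning.pytorch import seed_everything, Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint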

Pull Request resolved: https://github.com/pytorch/audio/pull/3396

Reviewed By: mthrok

Differential Revision: D46345206

Pulled By: nateanl

fbshipit-source-id: 59469c15dc5fe5466a99a5b5380eb4f98c2c633f
parent b7d3e89a
@@ -12,11 +12,10 @@ import pathlib
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter
 from typing import Tuple
 
-from lightning import HuBERTFineTuneModule
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.utilities.seed import seed_everything
+from lightning.pytorch import seed_everything, Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning_modules import HuBERTFineTuneModule
 
 logger = logging.getLogger(__name__)
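Note that the recipes' local lightning.py module, which provides HuBERTFineTuneModule, HuBERTPreTrainModule, and Batch, is renamed to lightning_modules.py throughout. The rename appears necessary because a sibling lightning.py would shadow the installed lightning package once the scripts import from it; a sketch of the failure mode (illustrative, not from this diff):

    # With a local lightning.py earlier on sys.path than site-packages:
    from lightning.pytorch import Trainer
    # ModuleNotFoundError: No module named 'lightning.pytorch';
    # 'lightning' is not a package (it resolved to the local file)

    # After the rename, library and recipe imports coexist:
    from lightning.pytorch import Trainer               # the Lightning 2.x library
    from lightning_modules import HuBERTFineTuneModule  # the local recipe module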
@@ -56,10 +55,10 @@ def run_train(args):
         default_root_dir=args.exp_dir,
         max_steps=args.max_updates,
         num_nodes=args.num_nodes,
-        gpus=args.gpus,
+        devices=args.gpus,
         accelerator="gpu",
-        strategy="ddp",
-        replace_sampler_ddp=False,
+        strategy="ddp_find_unused_parameters_true",
+        use_distributed_sampler=False,
         callbacks=callbacks,
         reload_dataloaders_every_n_epochs=1,
         val_check_interval=500,
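The Trainer changes are likewise mechanical 1.x-to-2.x renames: gpus becomes devices, replace_sampler_ddp becomes use_distributed_sampler, and, because Lightning 2.0 changed the DDP default to find_unused_parameters=False, the registered "ddp_find_unused_parameters_true" strategy alias is used to keep the old "ddp" behavior. A minimal sketch of the new call (argument values illustrative; these recipes pass use_distributed_sampler=False because they ship their own DistributedBatchSampler):

    from lightning.pytorch import Trainer

    trainer = Trainer(
        devices=2,                                   # was: gpus=2
        accelerator="gpu",
        strategy="ddp_find_unused_parameters_true",  # was: strategy="ddp" (1.x defaulted find_unused_parameters=True)
        use_distributed_sampler=False,               # was: replace_sampler_ddp=False
    )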
@@ -14,8 +14,8 @@ from dataset import (
     DistributedBatchSampler,
     HuBERTDataSet,
 )
+from lightning.pytorch import LightningModule
 from loss import hubert_loss
-from pytorch_lightning import LightningModule
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
@@ -9,10 +9,10 @@ import pathlib
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter
 from typing import Tuple
 
-from lightning import HuBERTPreTrainModule
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.utilities.seed import seed_everything
+from lightning.pytorch import seed_everything, Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning_modules import HuBERTPreTrainModule
 
 logger = logging.getLogger(__name__)
@@ -52,10 +52,10 @@ def run_train(args):
         default_root_dir=args.exp_dir,
         max_steps=args.max_updates,
         num_nodes=args.num_nodes,
-        gpus=args.gpus,
+        devices=args.gpus,
         accelerator="gpu",
-        strategy="ddp",
-        replace_sampler_ddp=False,
+        strategy="ddp_find_unused_parameters_true",
+        use_distributed_sampler=False,
         callbacks=callbacks,
         reload_dataloaders_every_n_epochs=1,
     )
@@ -9,7 +9,7 @@ import torchaudio
 from torch import Tensor
 from torch.utils.data import BatchSampler, Dataset, DistributedSampler
 
-from ..lightning import Batch
+from ..lightning_modules import Batch
 
 
 class BucketizeBatchSampler(BatchSampler):

 from collections import namedtuple
 from typing import Callable, Optional
 
-import pytorch_lightning as pl
+import lightning.pytorch as pl
 import torch
 import torch.nn as nn
 from torch.optim.optimizer import Optimizer
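Importing the new package under the familiar pl alias keeps every existing pl.LightningModule reference in the module working unchanged. A minimal sketch of the pattern (the class below is illustrative, not from this diff):

    import lightning.pytorch as pl  # was: import pytorch_lightning as pl
    import torch

    class DemoModule(pl.LightningModule):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(16, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)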
@@ -6,12 +6,11 @@ from typing import Dict, Tuple
 
 import torch
 import torchaudio.models
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.utilities.seed import seed_everything
+from lightning.pytorch import seed_everything, Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint
 
 from .data_modules import HuBERTDataModule
-from .lightning import SSLPretrainModule
+from .lightning_modules import SSLPretrainModule
 from .losses import hubert_loss
 from .lr_schedulers import LinearDecayLRScheduler
from .lr_schedulers import LinearDecayLRScheduler
@@ -102,11 +101,11 @@ def run_train(args):
         num_nodes=args.num_nodes,
         devices=args.gpus,
         accelerator="gpu",
-        strategy="ddp",
+        strategy="ddp_find_unused_parameters_true",
         precision=args.precision,
         accumulate_grad_batches=args.accumulate_grad_batches,
         gradient_clip_val=args.clip_norm,
-        replace_sampler_ddp=False,
+        use_distributed_sampler=False,
         callbacks=callbacks,
         reload_dataloaders_every_n_epochs=1,
     )