Commit e9083571 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Update HuBERT/SSL training recipes to support Lightning 2.x (#3396)

Summary:
There are some BC-Breaking changes from pytorch_lightning to lightning library. The PR adjust those changes to support latest lightning library.

Pull Request resolved: https://github.com/pytorch/audio/pull/3396

Reviewed By: mthrok

Differential Revision: D46345206

Pulled By: nateanl

fbshipit-source-id: 59469c15dc5fe5466a99a5b5380eb4f98c2c633f
parent b7d3e89a
...@@ -12,11 +12,10 @@ import pathlib ...@@ -12,11 +12,10 @@ import pathlib
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter
from typing import Tuple from typing import Tuple
from lightning import HuBERTFineTuneModule from lightning.pytorch import seed_everything, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer from lightning_modules import HuBERTFineTuneModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.seed import seed_everything
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -56,10 +55,10 @@ def run_train(args): ...@@ -56,10 +55,10 @@ def run_train(args):
default_root_dir=args.exp_dir, default_root_dir=args.exp_dir,
max_steps=args.max_updates, max_steps=args.max_updates,
num_nodes=args.num_nodes, num_nodes=args.num_nodes,
gpus=args.gpus, devices=args.gpus,
accelerator="gpu", accelerator="gpu",
strategy="ddp", strategy="ddp_find_unused_parameters_true",
replace_sampler_ddp=False, use_distributed_sampler=False,
callbacks=callbacks, callbacks=callbacks,
reload_dataloaders_every_n_epochs=1, reload_dataloaders_every_n_epochs=1,
val_check_interval=500, val_check_interval=500,
......
...@@ -14,8 +14,8 @@ from dataset import ( ...@@ -14,8 +14,8 @@ from dataset import (
DistributedBatchSampler, DistributedBatchSampler,
HuBERTDataSet, HuBERTDataSet,
) )
from lightning.pytorch import LightningModule
from loss import hubert_loss from loss import hubert_loss
from pytorch_lightning import LightningModule
from torch import Tensor from torch import Tensor
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
......
...@@ -9,10 +9,10 @@ import pathlib ...@@ -9,10 +9,10 @@ import pathlib
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawDescriptionHelpFormatter
from typing import Tuple from typing import Tuple
from lightning import HuBERTPreTrainModule from lightning.pytorch import seed_everything, Trainer
from pytorch_lightning import Trainer from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.seed import seed_everything from lightning_modules import HuBERTPreTrainModule
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -52,10 +52,10 @@ def run_train(args): ...@@ -52,10 +52,10 @@ def run_train(args):
default_root_dir=args.exp_dir, default_root_dir=args.exp_dir,
max_steps=args.max_updates, max_steps=args.max_updates,
num_nodes=args.num_nodes, num_nodes=args.num_nodes,
gpus=args.gpus, devices=args.gpus,
accelerator="gpu", accelerator="gpu",
strategy="ddp", strategy="ddp_find_unused_parameters_true",
replace_sampler_ddp=False, use_distributed_sampler=False,
callbacks=callbacks, callbacks=callbacks,
reload_dataloaders_every_n_epochs=1, reload_dataloaders_every_n_epochs=1,
) )
......
...@@ -9,7 +9,7 @@ import torchaudio ...@@ -9,7 +9,7 @@ import torchaudio
from torch import Tensor from torch import Tensor
from torch.utils.data import BatchSampler, Dataset, DistributedSampler from torch.utils.data import BatchSampler, Dataset, DistributedSampler
from ..lightning import Batch from ..lightning_modules import Batch
class BucketizeBatchSampler(BatchSampler): class BucketizeBatchSampler(BatchSampler):
......
from collections import namedtuple from collections import namedtuple
from typing import Callable, Optional from typing import Callable, Optional
import pytorch_lightning as pl import lightning.pytorch as pl
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
......
...@@ -6,12 +6,11 @@ from typing import Dict, Tuple ...@@ -6,12 +6,11 @@ from typing import Dict, Tuple
import torch import torch
import torchaudio.models import torchaudio.models
from pytorch_lightning import Trainer from lightning.pytorch import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.seed import seed_everything
from .data_modules import HuBERTDataModule from .data_modules import HuBERTDataModule
from .lightning import SSLPretrainModule from .lightning_modules import SSLPretrainModule
from .losses import hubert_loss from .losses import hubert_loss
from .lr_schedulers import LinearDecayLRScheduler from .lr_schedulers import LinearDecayLRScheduler
...@@ -102,11 +101,11 @@ def run_train(args): ...@@ -102,11 +101,11 @@ def run_train(args):
num_nodes=args.num_nodes, num_nodes=args.num_nodes,
devices=args.gpus, devices=args.gpus,
accelerator="gpu", accelerator="gpu",
strategy="ddp", strategy="ddp_find_unused_parameters_true",
precision=args.precision, precision=args.precision,
accumulate_grad_batches=args.accumulate_grad_batches, accumulate_grad_batches=args.accumulate_grad_batches,
gradient_clip_val=args.clip_norm, gradient_clip_val=args.clip_norm,
replace_sampler_ddp=False, use_distributed_sampler=False,
callbacks=callbacks, callbacks=callbacks,
reload_dataloaders_every_n_epochs=1, reload_dataloaders_every_n_epochs=1,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment