"git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "f72a48dd79bdb93afb75db6cfe3e48832e5ded0a"
Commit c6a52355 authored by hwangjeff, committed by Facebook GitHub Bot

Simplify train step in Conformer RNN-T LibriSpeech recipe (#2981)

Summary:
In the Conformer RNN-T LibriSpeech recipe, there's no need to perform manual optimization: the only custom behavior in the train step (gradient clipping and a once-per-epoch LR scheduler step) is something PyTorch Lightning can handle on its own. This PR modifies the recipe to use automatic optimization instead.
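For context, here is a minimal sketch contrasting the two modes in a LightningModule. It is illustrative only, not the recipe's code: the toy module names, the placeholder loss, and the hyperparameters are assumptions. With manual optimization the module drives the optimizer itself; with automatic optimization, `training_step` just returns the loss and the Trainer runs backward, clipping, and the optimizer/scheduler steps.

```python
# Hedged sketch contrasting manual vs. automatic optimization in PyTorch Lightning.
# The module names, placeholder loss, and hyperparameters are illustrative only.
import torch
from pytorch_lightning import LightningModule


class ManualToyModule(LightningModule):
    """Manual optimization: drive backward/clip/step by hand."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model
        self.automatic_optimization = False  # opt out of Lightning's loop

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.model(batch).mean()      # stand-in for a real loss
        self.manual_backward(loss)           # backward pass by hand
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10.0)
        opt.step()                           # parameter update by hand
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=8e-4)


class AutomaticToyModule(LightningModule):
    """Automatic optimization: just return the loss."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        # Lightning runs backward(), gradient clipping (via
        # Trainer(gradient_clip_val=...)), optimizer.step(), and
        # scheduler stepping on its own.
        return self.model(batch).mean()

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=8e-4)
```

The payoff is that cross-cutting concerns (clipping, scheduling) move into the Trainer configuration, keeping the module focused on computing the loss.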

Pull Request resolved: https://github.com/pytorch/audio/pull/2981

Reviewed By: mthrok

Differential Revision: D42507228

Pulled By: hwangjeff

fbshipit-source-id: 9712add951eba356e39f7e8c8dc3bf584ba48309
parent bb077284
@@ -95,8 +95,6 @@ class ConformerRNNTModule(LightningModule):
         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=8e-4, betas=(0.9, 0.98), eps=1e-9)
         self.warmup_lr_scheduler = WarmupLR(self.optimizer, 40, 120, 0.96)
 
-        self.automatic_optimization = False
-
     def _step(self, batch, _, step_type):
         if batch is None:
             return None
@@ -145,25 +143,13 @@ class ConformerRNNTModule(LightningModule):
         - Update parameters on each GPU.
 
         Doing so allows us to account for the variability in batch sizes that
-        variable-length sequential data commonly yields.
+        variable-length sequential data yield.
         """
-        opt = self.optimizers()
-        opt.zero_grad()
-
         loss = self._step(batch, batch_idx, "train")
         batch_size = batch.features.size(0)
         batch_sizes = self.all_gather(batch_size)
         self.log("Gathered batch size", batch_sizes.sum(), on_step=True, on_epoch=True)
         loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / batch size
-        self.manual_backward(loss)
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10.0)
-        opt.step()
-
-        # step every epoch
-        sch = self.lr_schedulers()
-        if self.trainer.is_last_batch:
-            sch.step()
-
         return loss
 
     def validation_step(self, batch, batch_idx):
...
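The loss scaling kept above is what makes automatic optimization sufficient here. DDP averages gradients over the N ranks, i.e. it computes (1/N) Σ_k g_k; scaling each rank's loss by N / Σ_k B_k (world size over total batch size) turns that average into (1/Σ_k B_k) Σ_k g_k, so the effective gradient is normalized by the total number of utterances in the global batch rather than by the number of GPUs. Gradient clipping moves into the Trainer (see the next hunk), and Lightning steps epoch-interval schedulers at epoch end on its own, which is presumably why the manual `is_last_batch` handling could be dropped. For reference, roughly the training step that remains, paraphrased from the diff context above (attribute names such as `batch.features` and the `_step` helper come from the recipe):

```python
# Simplified training_step under automatic optimization (paraphrased from the diff).
def training_step(self, batch, batch_idx):
    loss = self._step(batch, batch_idx, "train")
    batch_size = batch.features.size(0)
    batch_sizes = self.all_gather(batch_size)  # tensor with one entry per GPU
    self.log("Gathered batch size", batch_sizes.sum(), on_step=True, on_epoch=True)
    # DDP averages gradients over N ranks; scaling each rank's loss by
    # N / sum(batch_sizes) makes the final gradient a sum over all ranks
    # divided by the total batch size.
    loss *= batch_sizes.size(0) / batch_sizes.sum()
    return loss
```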
@@ -44,6 +44,7 @@ def run_train(args):
         strategy=DDPPlugin(find_unused_parameters=False),
         callbacks=callbacks,
         reload_dataloaders_every_n_epochs=1,
+        gradient_clip_val=10.0,
     )
 
     sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
...
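With this Trainer argument, the clipping previously done by the manual `torch.nn.utils.clip_grad_norm_(..., 10.0)` call is handled by Lightning itself; the default `gradient_clip_algorithm="norm"` makes `gradient_clip_val=10.0` equivalent to norm clipping at 10.0. A condensed view of the resulting setup is sketched below; only the arguments shown in the diff are taken from the recipe, while `callbacks` and the `DDPPlugin` import path are assumptions for the older Lightning version the recipe targets.

```python
# Condensed Trainer setup after this change (arguments beyond the diff are omitted;
# `callbacks` is a placeholder for the callbacks defined elsewhere in train.py).
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

callbacks = []  # e.g. checkpointing callbacks

trainer = Trainer(
    strategy=DDPPlugin(find_unused_parameters=False),
    callbacks=callbacks,
    reload_dataloaders_every_n_epochs=1,
    # Replaces the manual clip_grad_norm_ call; Lightning clips by norm by default.
    gradient_clip_val=10.0,
)
```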