Unverified Commit 5e893d6f authored by Pingchuan Ma's avatar Pingchuan Ma Committed by GitHub
Browse files

Simplify training step in av-asr recipe (#3598)

* Simplify training step in av-asr recipe

* Run pre-commit
parent 3e1d8f3c
......@@ -84,8 +84,6 @@ class ConformerRNNTModule(LightningModule):
betas=(0.9, 0.98),
)
self.automatic_optimization = False
def _step(self, batch, _, step_type):
if batch is None:
return None
......@@ -123,20 +121,10 @@ class ConformerRNNTModule(LightningModule):
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
batch_size = batch.inputs.size(0)
batch_sizes = self.all_gather(batch_size)
loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size
self.manual_backward(loss)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
opt.step()
sch = self.lr_schedulers()
sch.step()
self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
return loss
......
......@@ -80,8 +80,6 @@ class AVConformerRNNTModule(LightningModule):
betas=(0.9, 0.98),
)
self.automatic_optimization = False
def _step(self, batch, _, step_type):
if batch is None:
return None
......@@ -128,20 +126,10 @@ class AVConformerRNNTModule(LightningModule):
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
batch_size = batch.videos.size(0)
batch_sizes = self.all_gather(batch_size)
loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size
self.manual_backward(loss)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
opt.step()
sch = self.lr_schedulers()
sch.step()
self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
return loss
......
......@@ -36,6 +36,7 @@ def get_trainer(args):
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
gradient_clip_val=10.0,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment