Unverified Commit 5e893d6f authored by Pingchuan Ma, committed by GitHub

Simplify training step in av-asr recipe (#3598)

* Simplify training step in av-asr recipe

* Run pre-commit
parent 3e1d8f3c
@@ -84,8 +84,6 @@ class ConformerRNNTModule(LightningModule):
             betas=(0.9, 0.98),
         )
 
-        self.automatic_optimization = False
-
     def _step(self, batch, _, step_type):
         if batch is None:
             return None
@@ -123,20 +121,10 @@ class ConformerRNNTModule(LightningModule):
         return post_process_hypos(hypotheses, self.sp_model)[0][0]
 
     def training_step(self, batch, batch_idx):
-        opt = self.optimizers()
-        opt.zero_grad()
-
         loss = self._step(batch, batch_idx, "train")
         batch_size = batch.inputs.size(0)
         batch_sizes = self.all_gather(batch_size)
         loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / batch size
 
-        self.manual_backward(loss)
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
-        opt.step()
-
-        sch = self.lr_schedulers()
-        sch.step()
-
         self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
         return loss
...
@@ -80,8 +80,6 @@ class AVConformerRNNTModule(LightningModule):
             betas=(0.9, 0.98),
         )
 
-        self.automatic_optimization = False
-
     def _step(self, batch, _, step_type):
         if batch is None:
             return None
@@ -128,20 +126,10 @@ class AVConformerRNNTModule(LightningModule):
         return post_process_hypos(hypotheses, self.sp_model)[0][0]
 
     def training_step(self, batch, batch_idx):
-        opt = self.optimizers()
-        opt.zero_grad()
-
         loss = self._step(batch, batch_idx, "train")
         batch_size = batch.videos.size(0)
         batch_sizes = self.all_gather(batch_size)
         loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / batch size
 
-        self.manual_backward(loss)
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
-        opt.step()
-
-        sch = self.lr_schedulers()
-        sch.step()
-
         self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
         return loss
...
@@ -36,6 +36,7 @@ def get_trainer(args):
         strategy=DDPStrategy(find_unused_parameters=False),
         callbacks=callbacks,
         reload_dataloaders_every_n_epochs=1,
+        gradient_clip_val=10.0,
     )
...
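Net effect of the change: both recipe modules drop Lightning's manual optimization loop (`manual_backward`, the explicit `clip_grad_norm_`, `opt.step()`, `sch.step()`) and rely on automatic optimization, with gradient clipping moved onto the `Trainer` via `gradient_clip_val=10.0`. Below is a minimal sketch of that pattern; the toy module, layer, and scheduler are illustrative stand-ins, not the recipe's actual model or learning-rate schedule.

```python
# Minimal sketch (assumes pytorch_lightning is installed; ToyModule and its
# layer/optimizer/scheduler are placeholders, not the av-asr recipe's model).
import torch
from pytorch_lightning import LightningModule, Trainer


class ToyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)
        # No `self.automatic_optimization = False`: Lightning now handles
        # backward(), optimizer.step(), zero_grad(), and scheduler stepping.

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # Just compute and return the loss; no manual_backward(),
        # clip_grad_norm_(), opt.step(), or sch.step() here.
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=8e-4, betas=(0.9, 0.98))
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000)
        # interval="step" makes Lightning step the scheduler after every
        # optimizer step, as the removed explicit sch.step() call used to do.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }


# Gradient clipping becomes a Trainer argument instead of an explicit
# torch.nn.utils.clip_grad_norm_(..., 10) call inside training_step.
trainer = Trainer(max_epochs=1, gradient_clip_val=10.0)
```

Under automatic optimization, Lightning applies `gradient_clip_val` with norm-based clipping by default (`gradient_clip_algorithm="norm"`), so the Trainer-level setting reproduces the clipping the training step previously did by hand.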