Commit b6e244d2 authored by Tsahi Glik, committed by Facebook GitHub Bot

Add support for custom training step via meta_arch

Summary:
Add support in the default lightning task for running a custom training step defined on the meta arch, if one exists.
The goal is to allow a custom training step without having to inherit from the default lightning task class and override it. This lets us keep a single lightning task while still letting users customize the training step. In the long run this will be further encapsulated in a modeling hook, making it more modular and composable with other custom code.
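
As an illustration, a meta arch opts in by defining a training_step method with the signature used at the call site in this diff: training_step(batch, batch_idx, opt, manual_backward). A minimal sketch (the MyMetaArch name and the loss computation are illustrative only, not part of this change):

    import torch.nn as nn

    class MyMetaArch(nn.Module):  # hypothetical meta arch, for illustration only
        def training_step(self, batch, batch_idx, opt, manual_backward):
            # Compute per-loss values, as forward() does in training mode.
            loss_dict = self(batch)
            losses = sum(loss_dict.values())
            # The task disables automatic_optimization on this path, so the
            # meta arch drives the backward pass and the optimizer update.
            opt.zero_grad()
            manual_backward(losses)
            opt.step()
            loss_dict["total_loss"] = losses
            return loss_dict

The task still steps the LR scheduler, the EventStorage, and logging itself, so a custom training step only has to cover the forward/backward/optimizer logic.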

This change is a follow-up to the discussion in https://fburl.com/diff/yqlsypys.

Reviewed By: wat3rBro

Differential Revision: D33534624

fbshipit-source-id: 560f06da03f218e77ad46832be9d741417882c56
parent c687fb83
@@ -119,6 +119,11 @@ class DefaultTask(pl.LightningModule):
        self.save_hyperparameters()
        self.eval_res = None

        # Support custom training step in meta arch
        if hasattr(self.model, "training_step"):
            # activate manual optimization for custom training step
            self.automatic_optimization = False

        self.ema_state: Optional[EMAState] = None
        if cfg.MODEL_EMA.ENABLED:
            self.ema_state = EMAState(
@@ -164,6 +169,12 @@ class DefaultTask(pl.LightningModule):
        return task

    def training_step(self, batch, batch_idx):
        # Run the meta arch's custom training step if it defines one;
        # otherwise fall back to the standard training step.
        if hasattr(self.model, "training_step"):
            return self._meta_arch_training_step(batch, batch_idx)
        return self._standard_training_step(batch, batch_idx)
    def _standard_training_step(self, batch, batch_idx):
        loss_dict = self.forward(batch)
        losses = sum(loss_dict.values())
        loss_dict["total_loss"] = losses
@@ -171,6 +182,17 @@
        self.log_dict(loss_dict, prog_bar=True)
        return losses
    def _meta_arch_training_step(self, batch, batch_idx):
        opt = self.optimizers()
        # The meta arch owns backward and the optimizer update; hand it
        # the optimizer and the task's manual_backward helper.
        loss_dict = self.model.training_step(
            batch, batch_idx, opt, self.manual_backward
        )
        sch = self.lr_schedulers()
        sch.step()
        self.storage.step()
        self.log_dict(loss_dict, prog_bar=True)
        return loss_dict
    def test_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
        self._evaluation_step(batch, batch_idx, dataloader_idx)
@@ -213,3 +213,23 @@ class TestLightningTask(unittest.TestCase):
        cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
        model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
        self.assertTrue(isinstance(model.avgpool, torch.fx.GraphModule))

    @tempdir
    def test_meta_arch_training_step(self, tmp_dir):
        @META_ARCH_REGISTRY.register()
        class DetMetaArchForWithTrainingStep(mah.DetMetaArchForTest):
            def training_step(self, batch, batch_idx, opt, manual_backward):
                assert batch
                assert opt
                assert manual_backward
                return {"total_loss": 0.4}

        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL.META_ARCHITECTURE = "DetMetaArchForWithTrainingStep"

        task = GeneralizedRCNNTask(cfg)
        trainer = self._get_trainer(tmp_dir)
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)