Commit c687fb83 authored by Tsahi Glik's avatar Tsahi Glik Committed by Facebook GitHub Bot
Browse files

Person segmentation using torch lightning

Summary:
Add option to train Person Instance Segmentation using lightning instead of D2 (https://github.com/facebookresearch/d2go/commit/7992f91324aee6ae59795063a007c6837e60cdb8).
This is needed because we want to try PIS with SuperNet and our SuperNet based training is implemented in d2go lightning task

Reviewed By: zhanghang1989

Differential Revision: D33281437

fbshipit-source-id: e1b6567f3c77ce51240fb50d81350bc97735713a
parent 9d649b1e
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import pytorch_lightning as pl # type: ignore
from detectron2.utils.events import EventStorage
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
def get_lt_trainer(output_dir: str, cfg):
checkpoint_callback = ModelCheckpoint(dirpath=output_dir, save_last=True)
return pl.Trainer(
max_epochs=10 ** 8,
max_steps=cfg.SOLVER.MAX_ITER,
val_check_interval=cfg.TEST.EVAL_PERIOD
if cfg.TEST.EVAL_PERIOD > 0
else cfg.SOLVER.MAX_ITER,
callbacks=[checkpoint_callback],
logger=None,
)
def lt_train(task, trainer):
with EventStorage() as storage:
task.storage = storage
trainer.fit(task)
def lt_test(task, trainer):
with EventStorage() as storage:
task.storage = storage
trainer.test(task)
return task.eval_res
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment