Commit 914054ac authored by Tsahi Glik's avatar Tsahi Glik Committed by Facebook GitHub Bot
Browse files

enable overloading get_data_loader_vis_wrapper

Summary:
Currently the lightning task rely on the default runner for the vis wrapper logic. This does not allow to overload the get_data_loader_vis_wrapper is subclasses of the lightning task class.
This diff fixes this issue and properly take the vis wrapper given by the overloaded get_data_loader_vis_wrapper functions in the runner.

Reviewed By: zhanghang1989

Differential Revision: D33190410

fbshipit-source-id: 48cb3a8fa4b11df41d025d115d21002991549ced
parent 0e3323be
...@@ -7,9 +7,11 @@ from copy import deepcopy ...@@ -7,9 +7,11 @@ from copy import deepcopy
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Dict, List, Optional, Tuple, Type
import detectron2.utils.comm as comm
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.data.build import build_d2go_train_loader
from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
from d2go.data.utils import ( from d2go.data.utils import (
update_cfg_if_using_adhoc_dataset, update_cfg_if_using_adhoc_dataset,
...@@ -171,7 +173,6 @@ class DefaultTask(pl.LightningModule): ...@@ -171,7 +173,6 @@ class DefaultTask(pl.LightningModule):
losses = sum(loss_dict.values()) losses = sum(loss_dict.values())
loss_dict["total_loss"] = losses loss_dict["total_loss"] = losses
self.storage.step() self.storage.step()
self.log_dict(loss_dict, prog_bar=True) self.log_dict(loss_dict, prog_bar=True)
return losses return losses
...@@ -368,13 +369,28 @@ class DefaultTask(pl.LightningModule): ...@@ -368,13 +369,28 @@ class DefaultTask(pl.LightningModule):
return Detectron2GoRunner.get_visualization_evaluator() return Detectron2GoRunner.get_visualization_evaluator()
@staticmethod @staticmethod
def build_detection_train_loader(cfg, *args, mapper=None, **kwargs): def get_data_loader_vis_wrapper():
return Detectron2GoRunner.build_detection_train_loader(cfg, *args, **kwargs) return Detectron2GoRunner.get_data_loader_vis_wrapper()
@classmethod
def build_detection_train_loader(cls, cfg, *args, mapper=None, **kwargs):
mapper = mapper or cls.get_mapper(cfg, is_train=True)
data_loader = build_d2go_train_loader(cfg, mapper)
return cls._attach_visualizer_to_data_loader(cfg, data_loader)
@staticmethod @staticmethod
def build_detection_test_loader(cfg, dataset_name, mapper=None): def build_detection_test_loader(cfg, dataset_name, mapper=None):
return Detectron2GoRunner.build_detection_test_loader(cfg, dataset_name, mapper) return Detectron2GoRunner.build_detection_test_loader(cfg, dataset_name, mapper)
@classmethod
def _attach_visualizer_to_data_loader(cls, cfg, data_loader):
if comm.is_main_process():
data_loader_type = cls.get_data_loader_vis_wrapper()
if data_loader_type is not None:
tbx_writer = Detectron2GoRunner.get_tbx_writer(cfg)
data_loader = data_loader_type(cfg, tbx_writer, data_loader)
return data_loader
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Hooks # Hooks
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment