Commit 914054ac authored by Tsahi Glik's avatar Tsahi Glik Committed by Facebook GitHub Bot
Browse files

enable overloading get_data_loader_vis_wrapper

Summary:
Currently the lightning task rely on the default runner for the vis wrapper logic. This does not allow to overload the get_data_loader_vis_wrapper is subclasses of the lightning task class.
This diff fixes this issue and properly take the vis wrapper given by the overloaded get_data_loader_vis_wrapper functions in the runner.

Reviewed By: zhanghang1989

Differential Revision: D33190410

fbshipit-source-id: 48cb3a8fa4b11df41d025d115d21002991549ced
parent 0e3323be
......@@ -7,9 +7,11 @@ from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type
import detectron2.utils.comm as comm
import pytorch_lightning as pl
import torch
from d2go.config import CfgNode
from d2go.data.build import build_d2go_train_loader
from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
from d2go.data.utils import (
update_cfg_if_using_adhoc_dataset,
......@@ -171,7 +173,6 @@ class DefaultTask(pl.LightningModule):
losses = sum(loss_dict.values())
loss_dict["total_loss"] = losses
self.storage.step()
self.log_dict(loss_dict, prog_bar=True)
return losses
......@@ -368,13 +369,28 @@ class DefaultTask(pl.LightningModule):
return Detectron2GoRunner.get_visualization_evaluator()
@staticmethod
def build_detection_train_loader(cfg, *args, mapper=None, **kwargs):
return Detectron2GoRunner.build_detection_train_loader(cfg, *args, **kwargs)
def get_data_loader_vis_wrapper():
return Detectron2GoRunner.get_data_loader_vis_wrapper()
@classmethod
def build_detection_train_loader(cls, cfg, *args, mapper=None, **kwargs):
mapper = mapper or cls.get_mapper(cfg, is_train=True)
data_loader = build_d2go_train_loader(cfg, mapper)
return cls._attach_visualizer_to_data_loader(cfg, data_loader)
@staticmethod
def build_detection_test_loader(cfg, dataset_name, mapper=None):
return Detectron2GoRunner.build_detection_test_loader(cfg, dataset_name, mapper)
@classmethod
def _attach_visualizer_to_data_loader(cls, cfg, data_loader):
if comm.is_main_process():
data_loader_type = cls.get_data_loader_vis_wrapper()
if data_loader_type is not None:
tbx_writer = Detectron2GoRunner.get_tbx_writer(cfg)
data_loader = data_loader_type(cfg, tbx_writer, data_loader)
return data_loader
# ---------------------------------------------------------------------------
# Hooks
# ---------------------------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment