"tests/vscode:/vscode.git/clone" did not exist on "2de9e2df368241cf13f859cf51514cea4e53aed5"
Commit 536e9d25 authored by Kai Zhang, committed by Facebook GitHub Bot

add dataset visualization

Summary: Add dataset visualization so that we can visualize test results in TensorBoard.

Reviewed By: zhanghang1989

Differential Revision: D28457363

fbshipit-source-id: 4c2fd9dce349c6fb9e1cec51c9138cf0abb45d7b
parent fdd64119
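
At a glance, the change appends a VisualizationEvaluator to each dataset's evaluator list during testing, so rendered predictions land in the TensorBoard log directory derived from cfg.OUTPUT_DIR. As a rough sketch of the underlying mechanism only (plain torch.utils.tensorboard with a hypothetical log dir and a dummy image; the actual wiring is in the diff below):

    import torch
    from torch.utils.tensorboard import SummaryWriter

    # Hypothetical log dir; d2go derives the real one from cfg.OUTPUT_DIR
    # via get_tensorboard_log_dir().
    writer = SummaryWriter(log_dir="./output/tb_logs")

    # Dummy CHW image standing in for a rendered prediction.
    image = torch.rand(3, 480, 640)
    writer.add_image("eval/sample_prediction", image, global_step=0)
    writer.close()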
@@ -5,7 +5,7 @@ import logging
 import os
 from copy import deepcopy
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Type
 import pytorch_lightning as pl
 import torch
@@ -18,20 +18,26 @@ from d2go.export.d2_meta_arch import patch_d2_meta_arch
 from d2go.modeling.model_freezing_utils import (
     set_requires_grad,
 )
+from d2go.modeling.quantization import (
+    default_prepare_for_quant,
+    default_prepare_for_quant_convert,
+)
+from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
 from d2go.runner.default_runner import (
     Detectron2GoRunner,
     GeneralizedRCNNRunner,
+    _get_tbx_writer,
 )
 from d2go.setup import setup_after_lightning_launch
 from d2go.utils.ema_state import EMAState
-from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
+from d2go.utils.misc import get_tensorboard_log_dir
+from d2go.utils.visualization import VisualizationEvaluator
 from detectron2.modeling import build_model
 from detectron2.solver import (
     build_lr_scheduler as d2_build_lr_scheduler,
     build_optimizer as d2_build_optimizer,
 )
-from pytorch_lightning.utilities import rank_zero_info
-from d2go.modeling.quantization import default_prepare_for_quant, default_prepare_for_quant_convert
+from pytorch_lightning.utilities import rank_zero_only, rank_zero_info

 _STATE_DICT_KEY = "state_dict"
 _OLD_STATE_DICT_KEY = "model"
@@ -41,12 +47,12 @@ logger = logging.getLogger(__name__)

 def _is_lightning_checkpoint(checkpoint: Dict[str, Any]) -> bool:
-    """ Returns true if we believe this checkpoint to be a Lightning checkpoint. """
+    """Returns true if we believe this checkpoint to be a Lightning checkpoint."""
     return _STATE_DICT_KEY in checkpoint


 def _is_d2go_checkpoint(checkpoint: Dict[str, Any]) -> bool:
-    """ Returns true if we believe this to be a D2Go checkpoint. """
+    """Returns true if we believe this to be a D2Go checkpoint."""
     d2_go_keys = [_OLD_STATE_DICT_KEY, "optimizer", "scheduler", "iteration"]
     for key in d2_go_keys:
         if key not in checkpoint:
@@ -55,7 +61,7 @@ def _is_d2go_checkpoint(checkpoint: Dict[str, Any]) -> bool:

 def _convert_to_lightning(d2_checkpoint: Dict[str, Any]) -> None:
-    """ Converst a D2Go Checkpoint to Lightning in-place by renaming keys."""
+    """Converts a D2Go checkpoint to Lightning in-place by renaming keys."""
     prefix = "model"  # based on DefaultTask.model.
     old_keys = list(d2_checkpoint[_OLD_STATE_DICT_KEY])
     for key in old_keys:
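
For orientation, the rename boils down to something like this minimal sketch (toy checkpoint with illustrative values; the rest of the function body is elided by the diff above):

    # Toy D2Go-style checkpoint; real ones also carry optimizer,
    # scheduler, and iteration entries.
    d2_checkpoint = {"model": {"backbone.weight": 0.0}, "iteration": 10}

    prefix = "model"  # based on DefaultTask.model
    d2_checkpoint["state_dict"] = {
        f"{prefix}.{k}": v for k, v in d2_checkpoint.pop("model").items()
    }
    # d2_checkpoint now holds "state_dict" with "model."-prefixed keys.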
@@ -237,6 +243,29 @@ class DefaultTask(pl.LightningModule):
                 dataset_name,
             )

+        @rank_zero_only
+        def _setup_visualization_evaluator(
+            evaluator,
+            dataset_name: str,
+            model_tag: ModelTag,
+        ) -> None:
+            logger.info("Adding visualization evaluator ...")
+            mapper = self.get_mapper(self.cfg, is_train=False)
+            vis_eval_type = self.get_visualization_evaluator()
+            # TODO: replace tbx_writer with Lightning's self.logger.experiment
+            tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(self.cfg.OUTPUT_DIR))
+            if vis_eval_type is not None:
+                evaluator._evaluators.append(
+                    vis_eval_type(
+                        self.cfg,
+                        tbx_writer,
+                        mapper,
+                        dataset_name,
+                        train_iter=self.trainer.global_step,
+                        tag_postfix=model_tag,
+                    )
+                )
+
         for tag, dataset_evaluators in self.dataset_evaluators.items():
             dataset_evaluators.clear()
             assert self.cfg.OUTPUT_DIR, "Expect output_dir to be specified in config"
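
Note the @rank_zero_only decorator on the new helper: under distributed training it makes the wrapped function execute only on the global rank-0 process and turns it into a no-op everywhere else, so TensorBoard images are written once rather than once per GPU. A self-contained sketch of that behavior:

    from pytorch_lightning.utilities import rank_zero_only

    @rank_zero_only
    def write_visualizations(message: str) -> None:
        # Executes on rank 0 only; on other ranks the call returns
        # None without running the body.
        print(message)

    write_visualizations("writing eval images once, not once per GPU")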
@@ -250,7 +279,7 @@ class DefaultTask(pl.LightningModule):
                 )
                 evaluator.reset()
                 dataset_evaluators.append(evaluator)
-                # TODO: add visualization evaluator
+                _setup_visualization_evaluator(evaluator, dataset_name, tag)

     def _evaluation_dataloader(self):
         # TODO: Support subsample n images
@@ -317,6 +346,10 @@ class DefaultTask(pl.LightningModule):
     def get_mapper(cfg, is_train):
         return Detectron2GoRunner.get_mapper(cfg, is_train)

+    @staticmethod
+    def get_visualization_evaluator() -> Optional[Type[VisualizationEvaluator]]:
+        return Detectron2GoRunner.get_visualization_evaluator()
+
     @staticmethod
     def build_detection_train_loader(cfg, *args, mapper=None, **kwargs):
         return Detectron2GoRunner.build_detection_train_loader(cfg, *args, **kwargs)
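
Because _setup_visualization_evaluator checks this hook's return value against None, a task can opt out of visualization entirely by overriding it. A hypothetical subclass (sketch; the class name is illustrative, and DefaultTask is assumed importable from d2go.runner.lightning_task as in the test imports below):

    from typing import Optional, Type

    from d2go.runner.lightning_task import DefaultTask
    from d2go.utils.visualization import VisualizationEvaluator

    class NoVisTask(DefaultTask):
        @staticmethod
        def get_visualization_evaluator() -> Optional[Type[VisualizationEvaluator]]:
            # Returning None makes _setup_visualization_evaluator a no-op.
            return None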
@@ -397,7 +430,6 @@ class DefaultTask(pl.LightningModule):
             self.model = default_prepare_for_quant(self.cfg, self.model)
         return self
-
     def prepare_for_quant_convert(self) -> pl.LightningModule:
         if hasattr(self.model, "prepare_for_quant_convert"):
             self.model = self.model.prepare_for_quant_convert(self.cfg)
@@ -405,6 +437,7 @@ class DefaultTask(pl.LightningModule):
             self.model = default_prepare_for_quant_convert(self.cfg, self.model)
         return self

+
 class GeneralizedRCNNTask(DefaultTask):
     @classmethod
     def get_default_cfg(cls):
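
Both quantization hooks above follow the same hook-with-fallback pattern: prefer a method the model defines itself, otherwise use the shared default. The pattern in isolation (sketch; the wrapper function name is illustrative):

    from d2go.modeling.quantization import default_prepare_for_quant_convert

    def dispatch_convert(cfg, model):
        # Prefer the model's own hook when it defines one...
        if hasattr(model, "prepare_for_quant_convert"):
            return model.prepare_for_quant_convert(cfg)
        # ...otherwise fall back to the shared default.
        return default_prepare_for_quant_convert(cfg, model)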
@@ -18,7 +18,6 @@ from d2go.runner.lightning_task import GeneralizedRCNNTask
 from d2go.utils.testing import meta_arch_helper as mah
 from d2go.utils.testing.helper import tempdir
 from detectron2.modeling import META_ARCH_REGISTRY
-from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
 from detectron2.utils.events import EventStorage
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from torch import Tensor