Commit 00208026 authored by Ajinkya Deogade's avatar Ajinkya Deogade Committed by Facebook GitHub Bot
Browse files

Runner: create a separate buck target

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/559

Create modular TARGETS for files inside `runner`.

Reviewed By: wat3rBro

Differential Revision: D45854271

fbshipit-source-id: a15ef475f72685ae8c3c73e0a83cf136a7285d3e
parent 03156ce7
...@@ -8,8 +8,9 @@ from d2go.data.build import ( ...@@ -8,8 +8,9 @@ from d2go.data.build import (
add_weighted_training_sampler_default_configs, add_weighted_training_sampler_default_configs,
) )
from d2go.data.config import add_d2go_data_default_configs from d2go.data.config import add_d2go_data_default_configs
from d2go.modeling import ema, kmeans_anchors
from d2go.modeling.backbone.fbnet_cfg import add_fbnet_v2_default_configs from d2go.modeling.backbone.fbnet_cfg import add_fbnet_v2_default_configs
from d2go.modeling.ema import add_model_ema_configs
from d2go.modeling.kmeans_anchors import add_kmeans_anchors_cfg
from d2go.modeling.meta_arch.fcos import add_fcos_configs from d2go.modeling.meta_arch.fcos import add_fcos_configs
from d2go.modeling.model_freezing_utils import add_model_freezing_configs from d2go.modeling.model_freezing_utils import add_model_freezing_configs
from d2go.modeling.subclass import add_subclass_configs from d2go.modeling.subclass import add_subclass_configs
...@@ -63,13 +64,13 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None: ...@@ -63,13 +64,13 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
# _C.MODEL.FROZEN_LAYER_REG_EXP # _C.MODEL.FROZEN_LAYER_REG_EXP
add_model_freezing_configs(_C) add_model_freezing_configs(_C)
# _C.MODEL other models # _C.MODEL other models
ema.add_model_ema_configs(_C) add_model_ema_configs(_C)
# _C.D2GO_DATA... # _C.D2GO_DATA...
add_d2go_data_default_configs(_C) add_d2go_data_default_configs(_C)
# _C.TENSORBOARD... # _C.TENSORBOARD...
add_tensorboard_default_configs(_C) add_tensorboard_default_configs(_C)
# _C.MODEL.KMEANS... # _C.MODEL.KMEANS...
kmeans_anchors.add_kmeans_anchors_cfg(_C) add_kmeans_anchors_cfg(_C)
# _C.QUANTIZATION # _C.QUANTIZATION
add_quantization_default_configs(_C) add_quantization_default_configs(_C)
# _C.DATASETS.TRAIN_REPEAT_FACTOR # _C.DATASETS.TRAIN_REPEAT_FACTOR
......
...@@ -8,7 +8,6 @@ from collections import OrderedDict ...@@ -8,7 +8,6 @@ from collections import OrderedDict
from functools import lru_cache from functools import lru_cache
from typing import List, Optional, Type, Union from typing import List, Optional, Type, Union
import d2go.utils.abnormal_checker as abnormal_checker
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
import torch import torch
from d2go.checkpoint.api import is_distributed_checkpoint from d2go.checkpoint.api import is_distributed_checkpoint
...@@ -25,8 +24,9 @@ from d2go.data.utils import ( ...@@ -25,8 +24,9 @@ from d2go.data.utils import (
update_cfg_if_using_adhoc_dataset, update_cfg_if_using_adhoc_dataset,
) )
from d2go.evaluation.evaluator import inference_on_dataset from d2go.evaluation.evaluator import inference_on_dataset
from d2go.modeling import ema, kmeans_anchors from d2go.modeling import ema
from d2go.modeling.api import build_d2go_model from d2go.modeling.api import build_d2go_model
from d2go.modeling.kmeans_anchors import compute_kmeans_anchors_hook
from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad
from d2go.optimizer.build import build_optimizer_mapper from d2go.optimizer.build import build_optimizer_mapper
from d2go.quantization.modeling import QATHook, setup_qat_model from d2go.quantization.modeling import QATHook, setup_qat_model
...@@ -43,6 +43,11 @@ from d2go.runner.training_hooks import ( ...@@ -43,6 +43,11 @@ from d2go.runner.training_hooks import (
) )
from d2go.trainer.fsdp import get_grad_scaler from d2go.trainer.fsdp import get_grad_scaler
from d2go.trainer.helper import parse_precision_from_string from d2go.trainer.helper import parse_precision_from_string
from d2go.utils.abnormal_checker import (
AbnormalLossChecker,
AbnormalLossCheckerWrapper,
get_writers,
)
from d2go.utils.flop_calculator import attach_profilers from d2go.utils.flop_calculator import attach_profilers
from d2go.utils.gpu_memory_profiler import attach_oom_logger from d2go.utils.gpu_memory_profiler import attach_oom_logger
from d2go.utils.helper import D2Trainer, TensorboardXWriter from d2go.utils.helper import D2Trainer, TensorboardXWriter
...@@ -525,7 +530,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner): ...@@ -525,7 +530,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
lambda: self.do_test(cfg, model, train_iter=trainer.iter), lambda: self.do_test(cfg, model, train_iter=trainer.iter),
eval_after_train=False, # done by a separate do_test call in tools/train_net.py eval_after_train=False, # done by a separate do_test call in tools/train_net.py
), ),
kmeans_anchors.compute_kmeans_anchors_hook(self, cfg), compute_kmeans_anchors_hook(self, cfg),
self._create_qat_hook(cfg) if cfg.QUANTIZATION.QAT.ENABLED else None, self._create_qat_hook(cfg) if cfg.QUANTIZATION.QAT.ENABLED else None,
] ]
...@@ -577,9 +582,9 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner): ...@@ -577,9 +582,9 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
return model return model
tbx_writer = self.get_tbx_writer(cfg) tbx_writer = self.get_tbx_writer(cfg)
writers = abnormal_checker.get_writers(cfg, tbx_writer) writers = get_writers(cfg, tbx_writer)
checker = abnormal_checker.AbnormalLossChecker(start_iter, writers) checker = AbnormalLossChecker(start_iter, writers)
ret = abnormal_checker.AbnormalLossCheckerWrapper(model, checker) ret = AbnormalLossCheckerWrapper(model, checker)
return ret return ret
if cfg.SOLVER.AMP.ENABLED: if cfg.SOLVER.AMP.ENABLED:
......
...@@ -23,7 +23,9 @@ from d2go.distributed import ( ...@@ -23,7 +23,9 @@ from d2go.distributed import (
get_local_rank, get_local_rank,
get_num_processes_per_machine, get_num_processes_per_machine,
) )
from d2go.runner import BaseRunner, import_runner, RunnerV2Mixin from d2go.runner import import_runner
from d2go.runner.api import RunnerV2Mixin
from d2go.runner.default_runner import BaseRunner
from d2go.runner.lightning_task import DefaultTask from d2go.runner.lightning_task import DefaultTask
from d2go.utils.helper import run_once from d2go.utils.helper import run_once
from d2go.utils.launch_environment import get_launch_environment from d2go.utils.launch_environment import get_launch_environment
......
...@@ -10,7 +10,7 @@ from typing import Optional ...@@ -10,7 +10,7 @@ from typing import Optional
import d2go.data.transforms.box_utils as bu import d2go.data.transforms.box_utils as bu
import torch import torch
from d2go.export.exporter import convert_and_export_predictor from d2go.export.exporter import convert_and_export_predictor
from d2go.runner import GeneralizedRCNNRunner from d2go.runner.default_runner import GeneralizedRCNNRunner
from d2go.utils.testing.data_loader_helper import ( from d2go.utils.testing.data_loader_helper import (
create_detection_data_loader_on_toy_dataset, create_detection_data_loader_on_toy_dataset,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment