Commit 4c746dbe authored by Michael Snower's avatar Michael Snower Committed by Facebook GitHub Bot
Browse files

d2go profiler registry

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/201

Adds profiler registry.

Reviewed By: Maninae, wat3rBro

Differential Revision: D34725664

fbshipit-source-id: 52cb99b618e5ba5f9bd8d272d4dcaa770d66983a
parent 7a1213a0
...@@ -38,7 +38,7 @@ from d2go.modeling.quantization import ( ...@@ -38,7 +38,7 @@ from d2go.modeling.quantization import (
QATHook, QATHook,
) )
from d2go.optimizer import build_optimizer_mapper from d2go.optimizer import build_optimizer_mapper
from d2go.utils.flop_calculator import add_flop_printing_hook from d2go.utils.flop_calculator import attach_profilers
from d2go.utils.get_default_cfg import get_default_cfg from d2go.utils.get_default_cfg import get_default_cfg
from d2go.utils.helper import TensorboardXWriter, D2Trainer from d2go.utils.helper import TensorboardXWriter, D2Trainer
from d2go.utils.misc import get_tensorboard_log_dir from d2go.utils.misc import get_tensorboard_log_dir
...@@ -176,6 +176,9 @@ class BaseRunner(object): ...@@ -176,6 +176,9 @@ class BaseRunner(object):
) # upgrade from D2's CfgNode to D2Go's CfgNode ) # upgrade from D2's CfgNode to D2Go's CfgNode
cfg.SOLVER.AUTO_SCALING_METHODS = ["default_scale_d2_configs"] cfg.SOLVER.AUTO_SCALING_METHODS = ["default_scale_d2_configs"]
cfg.PROFILERS = ["default_flop_counter"]
return cfg return cfg
def build_model(self, cfg, eval_only=False): def build_model(self, cfg, eval_only=False):
...@@ -314,7 +317,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -314,7 +317,7 @@ class Detectron2GoRunner(BaseRunner):
dataset_name, dataset_name,
) )
add_flop_printing_hook(model, cfg.OUTPUT_DIR) attach_profilers(cfg, model)
results = OrderedDict() results = OrderedDict()
results[model_tag] = OrderedDict() results[model_tag] = OrderedDict()
...@@ -419,7 +422,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -419,7 +422,7 @@ class Detectron2GoRunner(BaseRunner):
def do_train(self, cfg, model, resume): def do_train(self, cfg, model, resume):
# Note that flops at the beginning of training is often inaccurate, # Note that flops at the beginning of training is often inaccurate,
# if a model has input-dependent logic # if a model has input-dependent logic
add_flop_printing_hook(model, cfg.OUTPUT_DIR) attach_profilers(cfg, model)
optimizer = self.build_optimizer(cfg, model) optimizer = self.build_optimizer(cfg, model)
scheduler = self.build_lr_scheduler(cfg, optimizer) scheduler = self.build_lr_scheduler(cfg, optimizer)
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# import to make sure Registry works
from . import flop_calculator # noqa
...@@ -12,9 +12,13 @@ import torch ...@@ -12,9 +12,13 @@ import torch
from d2go.utils.helper import run_once from d2go.utils.helper import run_once
from detectron2.utils.analysis import FlopCountAnalysis from detectron2.utils.analysis import FlopCountAnalysis
from detectron2.utils.file_io import PathManager from detectron2.utils.file_io import PathManager
from detectron2.utils.registry import Registry
from fvcore.nn import flop_count_table, flop_count_str from fvcore.nn import flop_count_table, flop_count_str
PROFILER_REGISTRY = Registry("PROFILER")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -41,6 +45,10 @@ def dump_flops_info(model, inputs, output_dir, use_eval_mode=True): ...@@ -41,6 +45,10 @@ def dump_flops_info(model, inputs, output_dir, use_eval_mode=True):
logger.info("Failed to deepcopy the model and skip FlopsEstimation.") logger.info("Failed to deepcopy the model and skip FlopsEstimation.")
return return
# delete other forward_pre_hooks so they are not simultaneously called
for k in model._forward_pre_hooks:
del model._forward_pre_hooks[k]
if use_eval_mode: if use_eval_mode:
model.eval() model.eval()
inputs = copy.deepcopy(inputs) inputs = copy.deepcopy(inputs)
...@@ -111,6 +119,12 @@ def add_flop_printing_hook( ...@@ -111,6 +119,12 @@ def add_flop_printing_hook(
handle = model.register_forward_pre_hook(hook) handle = model.register_forward_pre_hook(hook)
@PROFILER_REGISTRY.register()
def default_flop_counter(model, cfg):
return add_flop_printing_hook(model, cfg.OUTPUT_DIR)
# NOTE: the logging can be too long and messsy when printing flops multiple # NOTE: the logging can be too long and messsy when printing flops multiple
# times, especially when running eval during training, thus using `run_once` # times, especially when running eval during training, thus using `run_once`
# to limit it. `dump_flops_info` can log flops more concisely. # to limit it. `dump_flops_info` can log flops more concisely.
...@@ -156,3 +170,12 @@ def add_print_flops_callback(cfg, model, disable_after_callback=True): ...@@ -156,3 +170,12 @@ def add_print_flops_callback(cfg, model, disable_after_callback=True):
logger.info("Added callback to log flops info after the first inference") logger.info("Added callback to log flops info after the first inference")
fest.set_enable(True) fest.set_enable(True)
return fest return fest
def attach_profiler(profiler_name):
return PROFILER_REGISTRY.get(profiler_name)
def attach_profilers(cfg, model):
for profiler in cfg.PROFILERS:
attach_profiler(profiler)(model, cfg)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment