"git@developer.sourcefind.cn:change/sglang.git" did not exist on "c1815a99b78e1146e8c47020c1959f787cf31b10"
Commit 4c746dbe authored by Michael Snower's avatar Michael Snower Committed by Facebook GitHub Bot
Browse files

d2go profiler registry

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/201

Adds profiler registry.

Reviewed By: Maninae, wat3rBro

Differential Revision: D34725664

fbshipit-source-id: 52cb99b618e5ba5f9bd8d272d4dcaa770d66983a
parent 7a1213a0
......@@ -38,7 +38,7 @@ from d2go.modeling.quantization import (
QATHook,
)
from d2go.optimizer import build_optimizer_mapper
from d2go.utils.flop_calculator import add_flop_printing_hook
from d2go.utils.flop_calculator import attach_profilers
from d2go.utils.get_default_cfg import get_default_cfg
from d2go.utils.helper import TensorboardXWriter, D2Trainer
from d2go.utils.misc import get_tensorboard_log_dir
......@@ -176,6 +176,9 @@ class BaseRunner(object):
) # upgrade from D2's CfgNode to D2Go's CfgNode
cfg.SOLVER.AUTO_SCALING_METHODS = ["default_scale_d2_configs"]
cfg.PROFILERS = ["default_flop_counter"]
return cfg
def build_model(self, cfg, eval_only=False):
......@@ -314,7 +317,7 @@ class Detectron2GoRunner(BaseRunner):
dataset_name,
)
add_flop_printing_hook(model, cfg.OUTPUT_DIR)
attach_profilers(cfg, model)
results = OrderedDict()
results[model_tag] = OrderedDict()
......@@ -419,7 +422,7 @@ class Detectron2GoRunner(BaseRunner):
def do_train(self, cfg, model, resume):
# Note that flops at the beginning of training is often inaccurate,
# if a model has input-dependent logic
add_flop_printing_hook(model, cfg.OUTPUT_DIR)
attach_profilers(cfg, model)
optimizer = self.build_optimizer(cfg, model)
scheduler = self.build_lr_scheduler(cfg, optimizer)
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# import to make sure Registry works
from . import flop_calculator # noqa
......@@ -12,9 +12,13 @@ import torch
from d2go.utils.helper import run_once
from detectron2.utils.analysis import FlopCountAnalysis
from detectron2.utils.file_io import PathManager
from detectron2.utils.registry import Registry
from fvcore.nn import flop_count_table, flop_count_str
PROFILER_REGISTRY = Registry("PROFILER")
logger = logging.getLogger(__name__)
......@@ -41,6 +45,10 @@ def dump_flops_info(model, inputs, output_dir, use_eval_mode=True):
logger.info("Failed to deepcopy the model and skip FlopsEstimation.")
return
# Delete other forward_pre_hooks so they are not simultaneously called.
# Empty the mapping with clear() rather than deleting keys inside a loop over
# the same mapping: removing entries while iterating it can raise
# RuntimeError once hooks are registered.
model._forward_pre_hooks.clear()
if use_eval_mode:
model.eval()
inputs = copy.deepcopy(inputs)
......@@ -111,6 +119,12 @@ def add_flop_printing_hook(
handle = model.register_forward_pre_hook(hook)
@PROFILER_REGISTRY.register()
def default_flop_counter(model, cfg):
    """Profiler entry that attaches the flop printing hook to ``model``.

    Args:
        model: model forwarded unchanged to ``add_flop_printing_hook``.
        cfg: config object; only ``cfg.OUTPUT_DIR`` is read here.

    Returns:
        Whatever ``add_flop_printing_hook`` returns for this model.
    """
    output_dir = cfg.OUTPUT_DIR
    return add_flop_printing_hook(model, output_dir)
# NOTE: the logging can be too long and messy when printing flops multiple
# times, especially when running eval during training, thus using `run_once`
# to limit it. `dump_flops_info` can log flops more concisely.
......@@ -156,3 +170,12 @@ def add_print_flops_callback(cfg, model, disable_after_callback=True):
logger.info("Added callback to log flops info after the first inference")
fest.set_enable(True)
return fest
def attach_profiler(profiler_name):
    """Look up the entry stored in ``PROFILER_REGISTRY`` under ``profiler_name``.

    Nothing is attached to a model here; the looked-up object is handed back
    and the caller is responsible for invoking it.

    Args:
        profiler_name: key to look up in ``PROFILER_REGISTRY``.

    Returns:
        The result of ``PROFILER_REGISTRY.get(profiler_name)``.
    """
    profiler = PROFILER_REGISTRY.get(profiler_name)
    return profiler
def attach_profilers(cfg, model):
    """Invoke every profiler named in ``cfg.PROFILERS`` on ``model``.

    Each name is resolved through ``attach_profiler`` and the resolved
    callable is called as ``profiler(model, cfg)``. Return values of the
    individual profilers are discarded.

    Args:
        cfg: config object providing the iterable ``cfg.PROFILERS``.
        model: model passed as first argument to each profiler.
    """
    for name in cfg.PROFILERS:
        profiler_fn = attach_profiler(name)
        profiler_fn(model, cfg)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment