Commit 5509a138 authored by Yuxin Wu's avatar Yuxin Wu Committed by Facebook GitHub Bot
Browse files

enable flop printing & logging at the beginning of train & test

Reviewed By: zhanghang1989

Differential Revision: D29379832

fbshipit-source-id: 9283a8796a1dbee81b51611407c22f7d5a2069dc
parent 1894f8a3
...@@ -39,7 +39,7 @@ from d2go.modeling.quantization import ( ...@@ -39,7 +39,7 @@ from d2go.modeling.quantization import (
setup_qat_model, setup_qat_model,
) )
from d2go.optimizer import build_optimizer_mapper from d2go.optimizer import build_optimizer_mapper
from d2go.utils.flop_calculator import add_print_flops_callback from d2go.utils.flop_calculator import add_flop_printing_hook
from d2go.utils.get_default_cfg import get_default_cfg from d2go.utils.get_default_cfg import get_default_cfg
from d2go.utils.helper import TensorboardXWriter, D2Trainer from d2go.utils.helper import TensorboardXWriter, D2Trainer
from d2go.utils.misc import get_tensorboard_log_dir from d2go.utils.misc import get_tensorboard_log_dir
...@@ -307,7 +307,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -307,7 +307,7 @@ class Detectron2GoRunner(BaseRunner):
dataset_name, dataset_name,
) )
add_print_flops_callback(cfg, model, disable_after_callback=True) add_flop_printing_hook(model, cfg.OUTPUT_DIR)
results = OrderedDict() results = OrderedDict()
results[model_tag] = OrderedDict() results[model_tag] = OrderedDict()
...@@ -406,7 +406,9 @@ class Detectron2GoRunner(BaseRunner): ...@@ -406,7 +406,9 @@ class Detectron2GoRunner(BaseRunner):
return results return results
def do_train(self, cfg, model, resume): def do_train(self, cfg, model, resume):
add_print_flops_callback(cfg, model, disable_after_callback=True) # Note that flops at the beginning of training is often inaccurate,
# if a model has input-dependent logic
add_flop_printing_hook(model, cfg.OUTPUT_DIR)
optimizer = self.build_optimizer(cfg, model) optimizer = self.build_optimizer(cfg, model)
scheduler = self.build_lr_scheduler(cfg, optimizer) scheduler = self.build_lr_scheduler(cfg, optimizer)
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy import copy
import torch
import os import os
import logging import logging
...@@ -16,23 +17,32 @@ from d2go.utils.helper import run_once ...@@ -16,23 +17,32 @@ from d2go.utils.helper import run_once
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def dump_flops_info(model, inputs, output_dir): @torch.no_grad()
def dump_flops_info(model, inputs, output_dir, use_eval_mode=True):
""" """
Dump flops information about model, using the given model inputs. Dump flops information about model, using the given model inputs.
Information are dumped to output_dir using various flop counting tools Information are dumped to output_dir using various flop counting tools
in different formats. Only a simple table is printed to terminal. in different formats. Only a simple table is printed to terminal.
Args:
inputs: a tuple of positional arguments used to call model with.
use_eval_mode: turn the model into eval mode for flop counting. Otherwise,
will use the original mode. It's recommended to use eval mode, because
training mode typically follows a different codepath.
""" """
if not comm.is_main_process(): if not comm.is_main_process():
return return
logger.info("Evaluating model's number of parameters and FLOPS") logger.info("Evaluating model's number of parameters and FLOPS")
model = copy.deepcopy(model) model = copy.deepcopy(model)
model.eval() if use_eval_mode:
model.eval()
inputs = copy.deepcopy(inputs)
# 1. using mobile_cv flop counter # 1. using mobile_cv flop counter
try: try:
fest = flops_utils.FlopsEstimation(model) fest = flops_utils.FlopsEstimation(model)
with fest.enable(): with fest.enable():
model(inputs) model(*inputs)
fest.add_flops_info() fest.add_flops_info()
model_str = str(model) model_str = str(model)
output_file = os.path.join(output_dir, "flops_str_mobilecv.txt") output_file = os.path.join(output_dir, "flops_str_mobilecv.txt")
...@@ -59,18 +69,38 @@ def dump_flops_info(model, inputs, output_dir): ...@@ -59,18 +69,38 @@ def dump_flops_info(model, inputs, output_dir):
output_file = os.path.join(output_dir, "flops_table_fvcore.txt") output_file = os.path.join(output_dir, "flops_table_fvcore.txt")
with PathManager.open(output_file, "w") as f: with PathManager.open(output_file, "w") as f:
f.write(flops_table) f.write(flops_table)
logger.info(f"Flops table written to {output_file}") logger.info(f"Flops table (full version) written to {output_file}")
# 2.3: print a table with a shallow depth # 2.3: print a table with a shallow depth
flops_table = flop_count_table(flops, max_depth=3) flops_table = flop_count_table(flops, max_depth=3)
logger.info("Flops table:\n" + flops_table) logger.info("Flops table:\n" + flops_table)
except Exception: except Exception:
logger.exception("Failed to estimate flops using detectron2's FlopCountAnalysis") logger.exception("Failed to estimate flops using detectron2's FlopCountAnalysis")
return flops
def add_flop_printing_hook(
    model,
    output_dir: str,
):
    """
    Attach a one-shot forward pre-hook that prints/saves the model's flops.

    The hook fires the first time the model is called, unregisters itself
    immediately, and then delegates to ``dump_flops_info`` so the detailed
    flop breakdown is written under ``output_dir``.

    Args:
        model: the pytorch module to instrument.
        output_dir: directory to save more detailed flop info
    """

    def _print_flops_once(module, inputs):
        # Unregister first so flops are only ever reported once, even if
        # dump_flops_info ends up running the model again internally.
        hook_handle.remove()
        dump_flops_info(module, inputs, output_dir)
        return inputs

    hook_handle = model.register_forward_pre_hook(_print_flops_once)
# NOTE: the logging can be too long and messy when printing flops multiple # NOTE: the logging can be too long and messy when printing flops multiple
# times, especially when running eval during training, thus using `run_once` # times, especially when running eval during training, thus using `run_once`
# to limit it. TODO: log the flops more concisely. # to limit it. `dump_flops_info` can log flops more concisely.
@run_once() @run_once()
def add_print_flops_callback(cfg, model, disable_after_callback=True): def add_print_flops_callback(cfg, model, disable_after_callback=True):
def _print_flops_callback(self, model, model_data): def _print_flops_callback(self, model, model_data):
......
...@@ -14,7 +14,7 @@ class TestFlopCount(RCNNBaseTestCases.TemplateTestCase): ...@@ -14,7 +14,7 @@ class TestFlopCount(RCNNBaseTestCases.TemplateTestCase):
size_divisibility = max(self.test_model.backbone.size_divisibility, 10) size_divisibility = max(self.test_model.backbone.size_divisibility, 10)
h, w = size_divisibility, size_divisibility * 2 h, w = size_divisibility, size_divisibility * 2
with create_fake_detection_data_loader(h, w, is_train=False) as data_loader: with create_fake_detection_data_loader(h, w, is_train=False) as data_loader:
inputs = next(iter(data_loader)) inputs = (next(iter(data_loader)),)
with tempfile.TemporaryDirectory(prefix="d2go_test") as output_dir: with tempfile.TemporaryDirectory(prefix="d2go_test") as output_dir:
dump_flops_info(self.test_model, inputs, output_dir) dump_flops_info(self.test_model, inputs, output_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment