"docs/api/index.mdx" did not exist on "5f57b0ef4268a6bd9e8043d54c351a608a7e1bca"
Commit 5509a138 authored by Yuxin Wu's avatar Yuxin Wu Committed by Facebook GitHub Bot
Browse files

enable flop printing & logging at the beginning of train & test

Reviewed By: zhanghang1989

Differential Revision: D29379832

fbshipit-source-id: 9283a8796a1dbee81b51611407c22f7d5a2069dc
parent 1894f8a3
......@@ -39,7 +39,7 @@ from d2go.modeling.quantization import (
setup_qat_model,
)
from d2go.optimizer import build_optimizer_mapper
from d2go.utils.flop_calculator import add_print_flops_callback
from d2go.utils.flop_calculator import add_flop_printing_hook
from d2go.utils.get_default_cfg import get_default_cfg
from d2go.utils.helper import TensorboardXWriter, D2Trainer
from d2go.utils.misc import get_tensorboard_log_dir
......@@ -307,7 +307,7 @@ class Detectron2GoRunner(BaseRunner):
dataset_name,
)
add_print_flops_callback(cfg, model, disable_after_callback=True)
add_flop_printing_hook(model, cfg.OUTPUT_DIR)
results = OrderedDict()
results[model_tag] = OrderedDict()
......@@ -406,7 +406,9 @@ class Detectron2GoRunner(BaseRunner):
return results
def do_train(self, cfg, model, resume):
add_print_flops_callback(cfg, model, disable_after_callback=True)
# Note that flops at the beginning of training is often inaccurate,
# if a model has input-dependent logic
add_flop_printing_hook(model, cfg.OUTPUT_DIR)
optimizer = self.build_optimizer(cfg, model)
scheduler = self.build_lr_scheduler(cfg, optimizer)
......
......@@ -2,6 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import torch
import os
import logging
......@@ -16,23 +17,32 @@ from d2go.utils.helper import run_once
logger = logging.getLogger(__name__)
def dump_flops_info(model, inputs, output_dir):
@torch.no_grad()
def dump_flops_info(model, inputs, output_dir, use_eval_mode=True):
"""
Dump flops information about model, using the given model inputs.
Information are dumped to output_dir using various flop counting tools
in different formats. Only a simple table is printed to terminal.
Args:
inputs: a tuple of positional arguments used to call model with.
use_eval_mode: turn the model into eval mode for flop counting. Otherwise,
will use the original mode. It's recommended to use eval mode, because
training mode typically follows a different codepath.
"""
if not comm.is_main_process():
return
logger.info("Evaluating model's number of parameters and FLOPS")
model = copy.deepcopy(model)
if use_eval_mode:
model.eval()
inputs = copy.deepcopy(inputs)
# 1. using mobile_cv flop counter
try:
fest = flops_utils.FlopsEstimation(model)
with fest.enable():
model(inputs)
model(*inputs)
fest.add_flops_info()
model_str = str(model)
output_file = os.path.join(output_dir, "flops_str_mobilecv.txt")
......@@ -59,18 +69,38 @@ def dump_flops_info(model, inputs, output_dir):
output_file = os.path.join(output_dir, "flops_table_fvcore.txt")
with PathManager.open(output_file, "w") as f:
f.write(flops_table)
logger.info(f"Flops table written to {output_file}")
logger.info(f"Flops table (full version) written to {output_file}")
# 2.3: print a table with a shallow depth
flops_table = flop_count_table(flops, max_depth=3)
logger.info("Flops table:\n" + flops_table)
except Exception:
logger.exception("Failed to estimate flops using detectron2's FlopCountAnalysis")
return flops
def add_flop_printing_hook(
    model,
    output_dir: str,
):
    """
    Attach a one-shot forward pre-hook to ``model`` that prints/saves the
    flops of the whole model the first time the model is called.

    Args:
        model: the pytorch module to instrument.
        output_dir: directory where the detailed flop info is saved.
    """

    def _one_shot_flop_hook(module, inputs):
        # Detach the hook first so the (expensive) dump happens only once,
        # even if dump_flops_info itself triggers more forward calls.
        hook_handle.remove()
        dump_flops_info(module, inputs, output_dir)
        return inputs

    hook_handle = model.register_forward_pre_hook(_one_shot_flop_hook)
# NOTE: the logging can be too long and messy when printing flops multiple
# times, especially when running eval during training, thus using `run_once`
# to limit it. TODO: log the flops more concisely.
# to limit it. `dump_flops_info` can log flops more concisely.
@run_once()
def add_print_flops_callback(cfg, model, disable_after_callback=True):
def _print_flops_callback(self, model, model_data):
......
......@@ -14,7 +14,7 @@ class TestFlopCount(RCNNBaseTestCases.TemplateTestCase):
size_divisibility = max(self.test_model.backbone.size_divisibility, 10)
h, w = size_divisibility, size_divisibility * 2
with create_fake_detection_data_loader(h, w, is_train=False) as data_loader:
inputs = next(iter(data_loader))
inputs = (next(iter(data_loader)),)
with tempfile.TemporaryDirectory(prefix="d2go_test") as output_dir:
dump_flops_info(self.test_model, inputs, output_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment