Commit bc9d5070 authored by Yuxin Wu's avatar Yuxin Wu Committed by Facebook GitHub Bot
Browse files

additional flop counting using fvcore's flop counter

Summary:
1. save 3 versions of flop count, using both mobile_cv's flop counter and fvcore's flop counter
2. print only a simple short table in terminal, but save others to files

The `print_flops` function seems not used anywhere so this diff just replaced it.

TODO: enable this feature automatically for train/eval workflows in the next diff

Reviewed By: zhanghang1989

Differential Revision: D29182412

fbshipit-source-id: bfa1dfad41b99fcda06b96c4732237b5e753f1bb
parent 54b352d9
...@@ -2,9 +2,13 @@ ...@@ -2,9 +2,13 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy import copy
import os
import logging import logging
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
from detectron2.utils.file_io import PathManager
from detectron2.utils.analysis import FlopCountAnalysis
from fvcore.nn import flop_count_table, flop_count_str
import mobile_cv.lut.lib.pt.flops_utils as flops_utils import mobile_cv.lut.lib.pt.flops_utils as flops_utils
from d2go.utils.helper import run_once from d2go.utils.helper import run_once
...@@ -12,17 +16,56 @@ from d2go.utils.helper import run_once ...@@ -12,17 +16,56 @@ from d2go.utils.helper import run_once
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def print_flops(model, first_batch): def dump_flops_info(model, inputs, output_dir):
"""
Dump flops information about model, using the given model inputs.
Information are dumped to output_dir using various flop counting tools
in different formats. Only a simple table is printed to terminal.
"""
if not comm.is_main_process():
return
logger.info("Evaluating model's number of parameters and FLOPS") logger.info("Evaluating model's number of parameters and FLOPS")
model_flops = copy.deepcopy(model) model = copy.deepcopy(model)
model_flops.eval() model.eval()
fest = flops_utils.FlopsEstimation(model_flops)
with fest.enable(): # 1. using mobile_cv flop counter
model_flops(first_batch) try:
fest.add_flops_info() fest = flops_utils.FlopsEstimation(model)
model_str = str(model_flops) with fest.enable():
logger.info(model_str) model(inputs)
return model_str fest.add_flops_info()
model_str = str(model)
output_file = os.path.join(output_dir, "flops_str_mobilecv.txt")
with PathManager.open(output_file, "w") as f:
f.write(model_str)
logger.info(f"Flops info written to {output_file}")
except Exception:
logger.exception("Failed to estimate flops using mobile_cv's FlopsEstimation")
# 2. using d2/fvcore's flop counter
try:
flops = FlopCountAnalysis(model, inputs)
# 2.1: dump as model str
model_str = flop_count_str(flops)
output_file = os.path.join(output_dir, "flops_str_fvcore.txt")
with PathManager.open(output_file, "w") as f:
f.write(model_str)
logger.info(f"Flops info written to {output_file}")
# 2.2: dump as table
flops_table = flop_count_table(flops, max_depth=10)
output_file = os.path.join(output_dir, "flops_table_fvcore.txt")
with PathManager.open(output_file, "w") as f:
f.write(flops_table)
logger.info(f"Flops table written to {output_file}")
# 2.3: print a table with a shallow depth
flops_table = flop_count_table(flops, max_depth=3)
logger.info("Flops table:\n" + flops_table)
except Exception:
logger.exception("Failed to estimate flops using detectron2's FlopCountAnalysis")
# NOTE: the logging can be too long and messsy when printing flops multiple # NOTE: the logging can be too long and messsy when printing flops multiple
......
import tempfile
from d2go.utils.testing.rcnn_helper import RCNNBaseTestCases
from d2go.utils.flop_calculator import dump_flops_info
from d2go.utils.testing.data_loader_helper import create_fake_detection_data_loader
import os
class TestFlopCount(RCNNBaseTestCases.TemplateTestCase):
def setup_custom_test(self):
super().setup_custom_test()
self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")
def test_flop_count(self):
size_divisibility = max(self.test_model.backbone.size_divisibility, 10)
h, w = size_divisibility, size_divisibility * 2
with create_fake_detection_data_loader(h, w, is_train=False) as data_loader:
inputs = next(iter(data_loader))
with tempfile.TemporaryDirectory(prefix="d2go_test") as output_dir:
dump_flops_info(self.test_model, inputs, output_dir)
for fname in ["flops_str_mobilecv", "flops_str_fvcore", "flops_table_fvcore"]:
outf = os.path.join(output_dir, fname + ".txt")
self.assertTrue(os.path.isfile(outf))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment