Unverified Commit cda5fde6 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Add flops counter (#1127)

* add flops counter

* minor fix

* add forward_dummy() for most detectors

* add documentation for some tools
parent 225e9092
......@@ -162,6 +162,103 @@ pytorch [launch utility](https://pytorch.org/docs/stable/distributed_deprecated.
Usually it is slow if you do not have high speed networking like infiniband.
## Useful tools
### Analyze logs
You can plot loss/mAP curves given a training log file. Run `pip install seaborn` first to install the dependency.
![loss curve image](demo/loss_curve.png)
```shell
python tools/analyze_logs.py plot_curve [--keys ${KEYS}] [--title ${TITLE}] [--legend ${LEGEND}] [--backend ${BACKEND}] [--style ${STYLE}] [--out ${OUT_FILE}]
```
Examples:
- Plot the classification loss of some run.
```shell
python tools/analyze_logs.py plot_curve log.json --keys loss_cls --legend loss_cls
```
- Plot the classification and regression loss of some run, and save the figure to a pdf.
```shell
python tools/analyze_logs.py plot_curve log.json --keys loss_cls loss_reg --out losses.pdf
```
- Compare the bbox mAP of two runs in the same figure.
```shell
python tools/analyze_logs.py plot_curve log1.json log2.json --keys bbox_mAP --legend run1 run2
```
You can also compute the average training speed.
```shell
python tools/analyze_logs.py cal_train_time ${CONFIG_FILE} [--include-outliers]
```
The output is expected to be like the following.
```
-----Analyze train time of work_dirs/some_exp/20190611_192040.log.json-----
slowest epoch 11, average time is 1.2024
fastest epoch 1, average time is 1.1909
time std over epochs is 0.0028
average iter time: 1.1959 s/iter
```
### Get the FLOPs and params (experimental)
We provide a script adapted from [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch) to compute the FLOPs and params of a given model.
```shell
python tools/get_flops.py ${CONFIG_FILE} [--shape ${INPUT_SHAPE}]
```
You will get the result like this.
```
==============================
Input shape: (3, 1280, 800)
Flops: 239.32 GMac
Params: 37.74 M
==============================
```
**Note**: This tool is still experimental and we do not guarantee that the number is correct. You may well use the result for simple comparisons, but double check it before you adopt it in technical reports or papers.
(1) FLOPs are related to the input shape while parameters are not. The default input shape is (1, 3, 1280, 800).
(2) Some operators are not counted into FLOPs like GN and custom operators.
You can add support for new operators by modifying [`mmdet/utils/flops_counter.py`](mmdet/utils/flops_counter.py).
(3) The FLOPs of two-stage detectors is dependent on the number of proposals.
### Publish a model
Before you upload a model to AWS, you may want to
(1) convert model weights to CPU tensors, (2) delete the optimizer states and
(3) compute the hash of the checkpoint file and append the hash id to the filename.
```shell
python tools/publish_model.py ${INPUT_FILENAME} ${OUTPUT_FILENAME}
```
E.g.,
```shell
python tools/publish_model.py work_dirs/faster_rcnn/latest.pth faster_rcnn_r50_fpn_1x_20190801.pth
```
The final output filename will be `faster_rcnn_r50_fpn_1x_20190801-{hash id}.pth`.
### Test the robustness of detectors
Please refer to [ROBUSTNESS_BENCHMARKING.md](ROBUSTNESS_BENCHMARKING.md).
## How-to
### Use my own datasets
......
......@@ -117,6 +117,37 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
x = self.neck(x)
return x
def forward_dummy(self, img):
outs = ()
# backbone
x = self.extract_feat(img)
# rpn
if self.with_rpn:
rpn_outs = self.rpn_head(x)
outs = outs + (rpn_outs, )
proposals = torch.randn(1000, 4).cuda()
# bbox heads
rois = bbox2roi([proposals])
if self.with_bbox:
for i in range(self.num_stages):
bbox_feats = self.bbox_roi_extractor[i](
x[:self.bbox_roi_extractor[i].num_inputs], rois)
if self.with_shared_head:
bbox_feats = self.shared_head(bbox_feats)
cls_score, bbox_pred = self.bbox_head[i](bbox_feats)
outs = outs + (cls_score, bbox_pred)
# mask heads
if self.with_mask:
mask_rois = rois[:100]
for i in range(self.num_stages):
mask_feats = self.mask_roi_extractor[i](
x[:self.mask_roi_extractor[i].num_inputs], mask_rois)
if self.with_shared_head:
mask_feats = self.shared_head(mask_feats)
mask_pred = self.mask_head[i](mask_feats)
outs = outs + (mask_pred, )
return outs
def forward_train(self,
img,
img_meta,
......
......@@ -12,6 +12,30 @@ class DoubleHeadRCNN(TwoStageDetector):
super().__init__(**kwargs)
self.reg_roi_scale_factor = reg_roi_scale_factor
def forward_dummy(self, img):
outs = ()
# backbone
x = self.extract_feat(img)
# rpn
if self.with_rpn:
rpn_outs = self.rpn_head(x)
outs = outs + (rpn_outs, )
proposals = torch.randn(1000, 4).cuda()
# bbox head
rois = bbox2roi([proposals])
bbox_cls_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs], rois)
bbox_reg_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs],
rois,
roi_scale_factor=self.reg_roi_scale_factor)
if self.with_shared_head:
bbox_cls_feats = self.shared_head(bbox_cls_feats)
bbox_reg_feats = self.shared_head(bbox_reg_feats)
cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)
outs += (cls_score, bbox_pred)
return outs
def forward_train(self,
img,
img_meta,
......
......@@ -80,6 +80,31 @@ class GridRCNN(TwoStageDetector):
sampling_result.pos_bboxes = new_bboxes
return sampling_results
def forward_dummy(self, img):
outs = ()
# backbone
x = self.extract_feat(img)
# rpn
if self.with_rpn:
rpn_outs = self.rpn_head(x)
outs = outs + (rpn_outs, )
proposals = torch.randn(1000, 4).cuda()
# bbox head
rois = bbox2roi([proposals])
bbox_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs], rois)
if self.with_shared_head:
bbox_feats = self.shared_head(bbox_feats)
cls_score, bbox_pred = self.bbox_head(bbox_feats)
# grid head
grid_rois = rois[:100]
grid_feats = self.grid_roi_extractor(
x[:self.grid_roi_extractor.num_inputs], grid_rois)
if self.with_shared_head:
grid_feats = self.shared_head(grid_feats)
grid_pred = self.grid_head(grid_feats)
return rpn_outs, cls_score, bbox_pred, grid_pred
def forward_train(self,
img,
img_meta,
......
......@@ -153,6 +153,46 @@ class HybridTaskCascade(CascadeRCNN):
mask_pred = mask_head(mask_feats)
return mask_pred
def forward_dummy(self, img):
outs = ()
# backbone
x = self.extract_feat(img)
# rpn
if self.with_rpn:
rpn_outs = self.rpn_head(x)
outs = outs + (rpn_outs, )
proposals = torch.randn(1000, 4).cuda()
# semantic head
if self.with_semantic:
_, semantic_feat = self.semantic_head(x)
else:
semantic_feat = None
# bbox heads
rois = bbox2roi([proposals])
for i in range(self.num_stages):
cls_score, bbox_pred = self._bbox_forward_test(
i, x, rois, semantic_feat=semantic_feat)
outs = outs + (cls_score, bbox_pred)
# mask heads
if self.with_mask:
mask_rois = rois[:100]
mask_roi_extractor = self.mask_roi_extractor[-1]
mask_feats = mask_roi_extractor(
x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
if self.with_semantic and 'mask' in self.semantic_fusion:
mask_semantic_feat = self.semantic_roi_extractor(
[semantic_feat], mask_rois)
mask_feats += mask_semantic_feat
last_feat = None
for i in range(self.num_stages):
mask_head = self.mask_head[i]
if self.mask_info_flow:
mask_pred, last_feat = mask_head(mask_feats, last_feat)
else:
mask_pred = mask_head(mask_feats)
outs = outs + (mask_pred, )
return outs
def forward_train(self,
img,
img_meta,
......
......@@ -42,6 +42,9 @@ class MaskScoringRCNN(TwoStageDetector):
self.mask_iou_head = builder.build_head(mask_iou_head)
self.mask_iou_head.init_weights()
def forward_dummy(self, img):
raise NotImplementedError
# TODO: refactor forward_train in two stage to reduce code redundancy
def forward_train(self,
img,
......
......@@ -38,6 +38,11 @@ class RPN(BaseDetector, RPNTestMixin):
x = self.neck(x)
return x
def forward_dummy(self, img):
x = self.extract_feat(img)
rpn_outs = self.rpn_head(x)
return rpn_outs
def forward_train(self,
img,
img_meta,
......
......@@ -42,6 +42,11 @@ class SingleStageDetector(BaseDetector):
x = self.neck(x)
return x
def forward_dummy(self, img):
x = self.extract_feat(img)
outs = self.bbox_head(x)
return outs
def forward_train(self,
img,
img_metas,
......
......@@ -87,6 +87,35 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
x = self.neck(x)
return x
def forward_dummy(self, img):
outs = ()
# backbone
x = self.extract_feat(img)
# rpn
if self.with_rpn:
rpn_outs = self.rpn_head(x)
outs = outs + (rpn_outs, )
proposals = torch.randn(1000, 4).cuda()
# bbox head
rois = bbox2roi([proposals])
if self.with_bbox:
bbox_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs], rois)
if self.with_shared_head:
bbox_feats = self.shared_head(bbox_feats)
cls_score, bbox_pred = self.bbox_head(bbox_feats)
outs = outs + (cls_score, bbox_pred)
# mask head
if self.with_mask:
mask_rois = rois[:100]
mask_feats = self.mask_roi_extractor(
x[:self.mask_roi_extractor.num_inputs], mask_rois)
if self.with_shared_head:
mask_feats = self.shared_head(mask_feats)
mask_pred = self.mask_head(mask_feats)
outs = outs + (mask_pred, )
return outs
def forward_train(self,
img,
img_meta,
......
from .flops_counter import get_model_complexity_info
from .registry import Registry, build_from_cfg
__all__ = ['Registry', 'build_from_cfg']
__all__ = ['Registry', 'build_from_cfg', 'get_model_complexity_info']
# Modified from flops-counter.pytorch by Vladislav Sovrasov
# original repo: https://github.com/sovrasov/flops-counter.pytorch
# MIT License
# Copyright (c) 2018 Vladislav Sovrasov
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import sys
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
_AvgPoolNd, _MaxPoolNd)
CONV_TYPES = (_ConvNd, )
DECONV_TYPES = (_ConvTransposeMixin, )
LINEAR_TYPES = (nn.Linear, )
POOLING_TYPES = (_AvgPoolNd, _MaxPoolNd, _AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd)
RELU_TYPES = (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)
BN_TYPES = (_BatchNorm, )
UPSAMPLE_TYPES = (nn.Upsample, )
SUPPORTED_TYPES = (
CONV_TYPES + DECONV_TYPES + LINEAR_TYPES + POOLING_TYPES + RELU_TYPES +
BN_TYPES + UPSAMPLE_TYPES)
def get_model_complexity_info(model,
input_res,
print_per_layer_stat=True,
as_strings=True,
input_constructor=None,
ost=sys.stdout):
assert type(input_res) is tuple
assert len(input_res) >= 2
flops_model = add_flops_counting_methods(model)
flops_model.eval().start_flops_count()
if input_constructor:
input = input_constructor(input_res)
_ = flops_model(**input)
else:
batch = torch.ones(()).new_empty(
(1, *input_res),
dtype=next(flops_model.parameters()).dtype,
device=next(flops_model.parameters()).device)
flops_model(batch)
if print_per_layer_stat:
print_model_with_flops(flops_model, ost=ost)
flops_count = flops_model.compute_average_flops_cost()
params_count = get_model_parameters_number(flops_model)
flops_model.stop_flops_count()
if as_strings:
return flops_to_string(flops_count), params_to_string(params_count)
return flops_count, params_count
def flops_to_string(flops, units='GMac', precision=2):
if units is None:
if flops // 10**9 > 0:
return str(round(flops / 10.**9, precision)) + ' GMac'
elif flops // 10**6 > 0:
return str(round(flops / 10.**6, precision)) + ' MMac'
elif flops // 10**3 > 0:
return str(round(flops / 10.**3, precision)) + ' KMac'
else:
return str(flops) + ' Mac'
else:
if units == 'GMac':
return str(round(flops / 10.**9, precision)) + ' ' + units
elif units == 'MMac':
return str(round(flops / 10.**6, precision)) + ' ' + units
elif units == 'KMac':
return str(round(flops / 10.**3, precision)) + ' ' + units
else:
return str(flops) + ' Mac'
def params_to_string(params_num):
if params_num // 10**6 > 0:
return str(round(params_num / 10**6, 2)) + ' M'
elif params_num // 10**3:
return str(round(params_num / 10**3, 2)) + ' k'
else:
return str(params_num)
def print_model_with_flops(model, units='GMac', precision=3, ost=sys.stdout):
total_flops = model.compute_average_flops_cost()
def accumulate_flops(self):
if is_supported_instance(self):
return self.__flops__ / model.__batch_counter__
else:
sum = 0
for m in self.children():
sum += m.accumulate_flops()
return sum
def flops_repr(self):
accumulated_flops_cost = self.accumulate_flops()
return ', '.join([
flops_to_string(
accumulated_flops_cost, units=units, precision=precision),
'{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
self.original_extra_repr()
])
def add_extra_repr(m):
m.accumulate_flops = accumulate_flops.__get__(m)
flops_extra_repr = flops_repr.__get__(m)
if m.extra_repr != flops_extra_repr:
m.original_extra_repr = m.extra_repr
m.extra_repr = flops_extra_repr
assert m.extra_repr != m.original_extra_repr
def del_extra_repr(m):
if hasattr(m, 'original_extra_repr'):
m.extra_repr = m.original_extra_repr
del m.original_extra_repr
if hasattr(m, 'accumulate_flops'):
del m.accumulate_flops
model.apply(add_extra_repr)
print(model, file=ost)
model.apply(del_extra_repr)
def get_model_parameters_number(model):
params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
return params_num
def add_flops_counting_methods(net_main_module):
# adding additional methods to the existing module object,
# this is done this way so that each function has access to self object
net_main_module.start_flops_count = start_flops_count.__get__(
net_main_module)
net_main_module.stop_flops_count = stop_flops_count.__get__(
net_main_module)
net_main_module.reset_flops_count = reset_flops_count.__get__(
net_main_module)
net_main_module.compute_average_flops_cost = \
compute_average_flops_cost.__get__(net_main_module)
net_main_module.reset_flops_count()
# Adding variables necessary for masked flops computation
net_main_module.apply(add_flops_mask_variable_or_reset)
return net_main_module
def compute_average_flops_cost(self):
"""
A method that will be available after add_flops_counting_methods() is
called on a desired net object.
Returns current mean flops consumption per image.
"""
batches_count = self.__batch_counter__
flops_sum = 0
for module in self.modules():
if is_supported_instance(module):
flops_sum += module.__flops__
return flops_sum / batches_count
def start_flops_count(self):
"""
A method that will be available after add_flops_counting_methods() is
called on a desired net object.
Activates the computation of mean flops consumption per image.
Call it before you run the network.
"""
add_batch_counter_hook_function(self)
self.apply(add_flops_counter_hook_function)
def stop_flops_count(self):
"""
A method that will be available after add_flops_counting_methods() is
called on a desired net object.
Stops computing the mean flops consumption per image.
Call whenever you want to pause the computation.
"""
remove_batch_counter_hook_function(self)
self.apply(remove_flops_counter_hook_function)
def reset_flops_count(self):
"""
A method that will be available after add_flops_counting_methods() is
called on a desired net object.
Resets statistics computed so far.
"""
add_batch_counter_variables_or_reset(self)
self.apply(add_flops_counter_variable_or_reset)
def add_flops_mask(module, mask):
def add_flops_mask_func(module):
if isinstance(module, torch.nn.Conv2d):
module.__mask__ = mask
module.apply(add_flops_mask_func)
def remove_flops_mask(module):
module.apply(add_flops_mask_variable_or_reset)
def is_supported_instance(module):
if isinstance(module, SUPPORTED_TYPES):
return True
else:
return False
def empty_flops_counter_hook(module, input, output):
module.__flops__ += 0
def upsample_flops_counter_hook(module, input, output):
output_size = output[0]
batch_size = output_size.shape[0]
output_elements_count = batch_size
for val in output_size.shape[1:]:
output_elements_count *= val
module.__flops__ += int(output_elements_count)
def relu_flops_counter_hook(module, input, output):
active_elements_count = output.numel()
module.__flops__ += int(active_elements_count)
def linear_flops_counter_hook(module, input, output):
input = input[0]
batch_size = input.shape[0]
module.__flops__ += int(batch_size * input.shape[1] * output.shape[1])
def pool_flops_counter_hook(module, input, output):
input = input[0]
module.__flops__ += int(np.prod(input.shape))
def bn_flops_counter_hook(module, input, output):
module.affine
input = input[0]
batch_flops = np.prod(input.shape)
if module.affine:
batch_flops *= 2
module.__flops__ += int(batch_flops)
def deconv_flops_counter_hook(conv_module, input, output):
# Can have multiple inputs, getting the first one
input = input[0]
batch_size = input.shape[0]
input_height, input_width = input.shape[2:]
kernel_height, kernel_width = conv_module.kernel_size
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
filters_per_channel = out_channels // groups
conv_per_position_flops = (
kernel_height * kernel_width * in_channels * filters_per_channel)
active_elements_count = batch_size * input_height * input_width
overall_conv_flops = conv_per_position_flops * active_elements_count
bias_flops = 0
if conv_module.bias is not None:
output_height, output_width = output.shape[2:]
bias_flops = out_channels * batch_size * output_height * output_height
overall_flops = overall_conv_flops + bias_flops
conv_module.__flops__ += int(overall_flops)
def conv_flops_counter_hook(conv_module, input, output):
# Can have multiple inputs, getting the first one
input = input[0]
batch_size = input.shape[0]
output_dims = list(output.shape[2:])
kernel_dims = list(conv_module.kernel_size)
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
filters_per_channel = out_channels // groups
conv_per_position_flops = np.prod(
kernel_dims) * in_channels * filters_per_channel
active_elements_count = batch_size * np.prod(output_dims)
if conv_module.__mask__ is not None:
# (b, 1, h, w)
output_height, output_width = output.shape[2:]
flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height,
output_width)
active_elements_count = flops_mask.sum()
overall_conv_flops = conv_per_position_flops * active_elements_count
bias_flops = 0
if conv_module.bias is not None:
bias_flops = out_channels * active_elements_count
overall_flops = overall_conv_flops + bias_flops
conv_module.__flops__ += int(overall_flops)
def batch_counter_hook(module, input, output):
batch_size = 1
if len(input) > 0:
# Can have multiple inputs, getting the first one
input = input[0]
batch_size = len(input)
else:
print('Warning! No positional inputs found for a module, '
'assuming batch size is 1.')
module.__batch_counter__ += batch_size
def add_batch_counter_variables_or_reset(module):
module.__batch_counter__ = 0
def add_batch_counter_hook_function(module):
if hasattr(module, '__batch_counter_handle__'):
return
handle = module.register_forward_hook(batch_counter_hook)
module.__batch_counter_handle__ = handle
def remove_batch_counter_hook_function(module):
if hasattr(module, '__batch_counter_handle__'):
module.__batch_counter_handle__.remove()
del module.__batch_counter_handle__
def add_flops_counter_variable_or_reset(module):
if is_supported_instance(module):
module.__flops__ = 0
def add_flops_counter_hook_function(module):
if is_supported_instance(module):
if hasattr(module, '__flops_handle__'):
return
if isinstance(module, CONV_TYPES):
handle = module.register_forward_hook(conv_flops_counter_hook)
elif isinstance(module, RELU_TYPES):
handle = module.register_forward_hook(relu_flops_counter_hook)
elif isinstance(module, LINEAR_TYPES):
handle = module.register_forward_hook(linear_flops_counter_hook)
elif isinstance(module, POOLING_TYPES):
handle = module.register_forward_hook(pool_flops_counter_hook)
elif isinstance(module, BN_TYPES):
handle = module.register_forward_hook(bn_flops_counter_hook)
elif isinstance(module, UPSAMPLE_TYPES):
handle = module.register_forward_hook(upsample_flops_counter_hook)
elif isinstance(module, DECONV_TYPES):
handle = module.register_forward_hook(deconv_flops_counter_hook)
else:
handle = module.register_forward_hook(empty_flops_counter_hook)
module.__flops_handle__ = handle
def remove_flops_counter_hook_function(module):
if is_supported_instance(module):
if hasattr(module, '__flops_handle__'):
module.__flops_handle__.remove()
del module.__flops_handle__
# --- Masked flops counting
# Also being run in the initialization
def add_flops_mask_variable_or_reset(module):
if is_supported_instance(module):
module.__mask__ = None
import argparse
from mmcv import Config
from mmdet.models import build_detector
from mmdet.utils import get_model_complexity_info
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--shape',
type=int,
nargs='+',
default=[1280, 800],
help='input image size')
args = parser.parse_args()
return args
def main():
args = parse_args()
if len(args.shape) == 1:
input_shape = (3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (3, ) + tuple(args.shape)
else:
raise ValueError('invalid input shape')
cfg = Config.fromfile(args.config)
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda()
model.eval()
if hasattr(model, 'forward_dummy'):
model.forward = model.forward_dummy
else:
raise NotImplementedError(
'FLOPs counter is currently not currently supported with {}'.
format(model.__class__.__name__))
flops, params = get_model_complexity_info(model, input_shape)
split_line = '=' * 30
print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
split_line, input_shape, flops, params))
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment