"src/vscode:/vscode.git/clone" did not exist on "92121fc66819444daba11bcb625826497a36c514"
Unverified Commit d1c63562 authored by colorjam's avatar colorjam Committed by GitHub
Browse files

Add custom op support of counter (#2795)

parent f892ed67
@@ -12,7 +12,7 @@ except Exception as e:
     raise

-def count_flops_params(model: nn.Module, input_size, verbose=True):
+def count_flops_params(model: nn.Module, input_size, custom_ops=None, verbose=True):
     """
     Count FLOPs and Params of the given model.
     This function would identify the mask on the module
@@ -28,7 +28,10 @@ def count_flops_params(model: nn.Module, input_size, verbose=True):
         target model.
     input_size: list, tuple
         the input shape of data
+    custom_ops: dict
+        a mapping of (module: custom operation); the custom operation will overwrite the default operation.
+        for reference, please see ``custom_mask_ops``.

     Returns
     -------
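Note: with this change a caller can route any module type to its own counting hook. A minimal usage sketch (``MyCustomLayer`` and ``count_my_custom`` are hypothetical names, not part of this patch; hooks follow the thop convention of accumulating into ``m.total_ops``):

import torch
import torch.nn as nn

class MyCustomLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x):
        return x @ self.weight

def count_my_custom(m, x, y):
    # one multiply-accumulate per output element and input channel
    m.total_ops += torch.DoubleTensor([int(y.numel() * m.weight.size(0))])

model = nn.Sequential(nn.Linear(16, 16), MyCustomLayer(16))
flops, params = count_flops_params(model, (1, 16), custom_ops={MyCustomLayer: count_my_custom})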
@@ -44,11 +47,14 @@ def count_flops_params(model: nn.Module, input_size, verbose=True):
     inputs = torch.randn(input_size).to(device)
     hook_module_list = []
+    if custom_ops is None:
+        custom_ops = {}
+    custom_mask_ops.update(custom_ops)

     prev_m = None
     for m in model.modules():
         weight_mask = None
         m_type = type(m)
-        if m_type in custom_ops:
+        if m_type in custom_mask_ops:
             if isinstance(prev_m, PrunerModuleWrapper):
                 weight_mask = prev_m.weight_mask
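Because the user mapping is merged via ``custom_mask_ops.update(custom_ops)``, a user-supplied hook for a type that already has a default counter (for example ``nn.Linear``) takes precedence. A hedged sketch of such an override, reusing ``model`` from the sketch above (``count_linear_macs_only`` is a made-up name):

def count_linear_macs_only(m, x, y):
    # count only multiply-accumulates, ignoring the bias term
    m.total_ops += torch.DoubleTensor([int(y.numel() * m.in_features)])

flops, params = count_flops_params(model, (1, 16), custom_ops={nn.Linear: count_linear_macs_only})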
@@ -56,7 +62,7 @@ def count_flops_params(model: nn.Module, input_size, verbose=True):
             hook_module_list.append(m)
         prev_m = m

-    flops, params = profile(model, inputs=(inputs, ), custom_ops=custom_ops, verbose=verbose)
+    flops, params = profile(model, inputs=(inputs, ), custom_ops=custom_mask_ops, verbose=verbose)

     for m in hook_module_list:
@@ -74,7 +80,6 @@ def count_flops_params(model: nn.Module, input_size, verbose=True):
 def count_convNd_mask(m, x, y):
     """
     The forward hook to count FLOPs and Parameters of convolution operation.
-
     Parameters
     ----------
     m : torch.nn.Module
@@ -101,7 +106,6 @@ def count_convNd_mask(m, x, y):
 def count_linear_mask(m, x, y):
     """
     The forward hook to count FLOPs and Parameters of linear transformation.
-
     Parameters
     ----------
     m : torch.nn.Module
@@ -111,22 +115,21 @@ def count_linear_mask(m, x, y):
     y : torch.Tensor
         output data
     """
-    output_channel = y.size()[1]
-    output_size = torch.zeros(y.size()[2:]).numel()
+    output_channel = y.numel()

     bias_flops = 1 if m.bias is not None else 0

     if m.weight_mask is not None:
         output_channel = m.weight_mask.sum() // m.in_features

-    total_ops = output_channel * output_size * (m.in_features + bias_flops)
+    total_ops = output_channel * (m.in_features + bias_flops)

     m.total_ops += torch.DoubleTensor([int(total_ops)])

-custom_ops = {
+custom_mask_ops = {
     nn.Conv1d: count_convNd_mask,
     nn.Conv2d: count_convNd_mask,
     nn.Conv3d: count_convNd_mask,
     nn.Linear: count_linear_mask,
-}
+}
\ No newline at end of file
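As a sanity check on the rewritten linear counter: for an unmasked ``nn.Linear(64, 32)`` on a batch of 8, ``y.numel()`` is 8 * 32 = 256, so ``total_ops`` is 256 * (64 + 1) = 16,640 with bias. A standalone sketch of the same arithmetic (assumes no pruning mask is present):

import torch
import torch.nn as nn

m = nn.Linear(64, 32)
y = m(torch.randn(8, 64))
bias_flops = 1 if m.bias is not None else 0
total_ops = y.numel() * (m.in_features + bias_flops)
assert total_ops == 16640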