"docs/source/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "808b06557078626556ac855e196cbd788a9a9857"
Unverified Commit d1c63562 authored by colorjam's avatar colorjam Committed by GitHub
Browse files

Add custom op support of counter (#2795)

parent f892ed67
...@@ -12,7 +12,7 @@ except Exception as e: ...@@ -12,7 +12,7 @@ except Exception as e:
raise raise
def count_flops_params(model: nn.Module, input_size, verbose=True): def count_flops_params(model: nn.Module, input_size, custom_ops=None, verbose=True):
""" """
Count FLOPs and Params of the given model. Count FLOPs and Params of the given model.
This function would identify the mask on the module This function would identify the mask on the module
...@@ -28,7 +28,10 @@ def count_flops_params(model: nn.Module, input_size, verbose=True): ...@@ -28,7 +28,10 @@ def count_flops_params(model: nn.Module, input_size, verbose=True):
target model. target model.
input_size: list, tuple input_size: list, tuple
the input shape of data the input shape of data
custom_ops: dict
a mapping of (module: custom operation)
the custom operation will overwrite the default operation.
for reference, please see ``custom_mask_ops``.
Returns Returns
------- -------
...@@ -44,11 +47,14 @@ def count_flops_params(model: nn.Module, input_size, verbose=True): ...@@ -44,11 +47,14 @@ def count_flops_params(model: nn.Module, input_size, verbose=True):
inputs = torch.randn(input_size).to(device) inputs = torch.randn(input_size).to(device)
hook_module_list = [] hook_module_list = []
if custom_ops is None:
custom_ops = {}
custom_mask_ops.update(custom_ops)
prev_m = None prev_m = None
for m in model.modules(): for m in model.modules():
weight_mask = None weight_mask = None
m_type = type(m) m_type = type(m)
if m_type in custom_ops: if m_type in custom_mask_ops:
if isinstance(prev_m, PrunerModuleWrapper): if isinstance(prev_m, PrunerModuleWrapper):
weight_mask = prev_m.weight_mask weight_mask = prev_m.weight_mask
...@@ -56,7 +62,7 @@ def count_flops_params(model: nn.Module, input_size, verbose=True): ...@@ -56,7 +62,7 @@ def count_flops_params(model: nn.Module, input_size, verbose=True):
hook_module_list.append(m) hook_module_list.append(m)
prev_m = m prev_m = m
flops, params = profile(model, inputs=(inputs, ), custom_ops=custom_ops, verbose=verbose) flops, params = profile(model, inputs=(inputs, ), custom_ops=custom_mask_ops, verbose=verbose)
for m in hook_module_list: for m in hook_module_list:
...@@ -74,7 +80,6 @@ def count_flops_params(model: nn.Module, input_size, verbose=True): ...@@ -74,7 +80,6 @@ def count_flops_params(model: nn.Module, input_size, verbose=True):
def count_convNd_mask(m, x, y): def count_convNd_mask(m, x, y):
""" """
The forward hook to count FLOPs and Parameters of convolution operation. The forward hook to count FLOPs and Parameters of convolution operation.
Parameters Parameters
---------- ----------
m : torch.nn.Module m : torch.nn.Module
...@@ -101,7 +106,6 @@ def count_convNd_mask(m, x, y): ...@@ -101,7 +106,6 @@ def count_convNd_mask(m, x, y):
def count_linear_mask(m, x, y): def count_linear_mask(m, x, y):
""" """
The forward hook to count FLOPs and Parameters of linear transformation. The forward hook to count FLOPs and Parameters of linear transformation.
Parameters Parameters
---------- ----------
m : torch.nn.Module m : torch.nn.Module
...@@ -111,22 +115,21 @@ def count_linear_mask(m, x, y): ...@@ -111,22 +115,21 @@ def count_linear_mask(m, x, y):
y : torch.Tensor y : torch.Tensor
output data output data
""" """
output_channel = y.size()[1] output_channel = y.numel()
output_size = torch.zeros(y.size()[2:]).numel()
bias_flops = 1 if m.bias is not None else 0 bias_flops = 1 if m.bias is not None else 0
if m.weight_mask is not None: if m.weight_mask is not None:
output_channel = m.weight_mask.sum() // m.in_features output_channel = m.weight_mask.sum() // m.in_features
total_ops = output_channel * output_size * (m.in_features + bias_flops) total_ops = output_channel * (m.in_features + bias_flops)
m.total_ops += torch.DoubleTensor([int(total_ops)]) m.total_ops += torch.DoubleTensor([int(total_ops)])
# Default mapping of module types to mask-aware FLOPs/params counting hooks.
# ``count_flops_params`` merges any user-supplied ``custom_ops`` into this
# dict (via ``custom_mask_ops.update(custom_ops)``) before profiling, so the
# entries below act as the built-in fallback operations; user entries for the
# same module type overwrite them.
custom_mask_ops = {
    nn.Conv1d: count_convNd_mask,
    nn.Conv2d: count_convNd_mask,
    nn.Conv3d: count_convNd_mask,
    nn.Linear: count_linear_mask,
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment