"...bindings/c/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "6e09681e0ec52b6a75dc914c7b3bb124899b80f8"
Unverified commit 32efe8e7 authored by Super Daniel, committed by GitHub

[fx] add profiler for fx nodes. (#1480)

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] merge development into main (#1)

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.

* [fx] add rules to linearize computation graphs for searching. (#2)

* [fx] fix test and algorithm bugs in activation checkpointing.

* [fx] polish ckpt_test.

* [fx] remove chen_sqrt for sake of simplicity

* [fx] fix inconsistencies.

* [fx] fix MetaInfoProp.

* [fx] consider MetaInfoProp for inplace operands.

* [fx] add profiler for fx nodes.

* [fx] fix error in tests.

* [fx] unfix bug.
parent d39e11df
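For orientation, a minimal sketch of how the per-node attributes introduced by this commit are meant to be consumed. The MetaInfoProp import path and the toy model here are illustrative assumptions; the propagate() call and the dunder attributes match the diff below.

import torch
import torch.fx
import torch.nn.functional as F
# import path assumed for illustration; this commit edits the file that defines MetaInfoProp
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

class TinyNet(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(128, 64)

    def forward(self, x):
        return F.relu(self.fc(x))

gm = torch.fx.symbolic_trace(TinyNet())
# inputs must live on device='meta'; MetaInfoProp.run asserts this
MetaInfoProp(gm).propagate(torch.empty(8, 128, device='meta'))
for n in gm.graph.nodes:
    # __param__/__activation__ are byte counts, __flops__/__macs__ are op counts
    print(n.name, n.__param__, n.__activation__, n.__flops__, n.__macs__)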
@@ -73,10 +73,10 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
     y = 0
     prev_idx = 2
     for (idx, n) in enumerate(gm.graph.nodes):
-        temp += getattr(n, 'activation_size')
+        temp += getattr(n, '__activation__')
         y = max(y, temp)
         if temp > b and n in ckpt_nodes:
-            x += getattr(n, 'activation_size')
+            x += getattr(n, '__activation__')
             temp = 0
             ckpt_intv.append((prev_idx, idx + 1))
             prev_idx = idx + 1
...
+from operator import add, getitem
 import torch
 import torch.fx
-from torch.fx.node import Node, map_aggregate
+from torch.fx.node import Node, map_aggregate, Argument, Target
 from typing import Any, Tuple, NamedTuple, Optional, Dict
 from functools import reduce
 from torch.fx._compatibility import compatibility
 from torch.fx.immutable_collections import immutable_dict, immutable_list
+from colossalai.fx.profiler import MetaProfile, profile_function, profile_module, calculate_activation_size, profile_method

 @compatibility(is_backward_compatible=True)
@@ -36,47 +38,11 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
     return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor)
-def _compute_activation_size(node_metadata: any) -> int:
-    """
-    Compute numel of a node with ``tensor_meta`` attribute.
-    """
-    node_numel = 0
-    if isinstance(node_metadata, TensorMetadata):
-        node_numel += node_metadata.numel * torch.tensor([], dtype=node_metadata.dtype).element_size()
-    elif isinstance(node_metadata, dict):
-        value_list = [v for _, v in node_metadata.items()]
-        node_numel += _compute_activation_size(value_list)
-    else:
-        for element in node_metadata:
-            node_numel += _compute_activation_size(element)
-    return node_numel
-
-
-def _map_aggregate(arg, fn):
-    """
-    Apply fn to each Node appearing in arg. arg may be a list, tuple, slice, or dict with string keys.
-    """
-    if isinstance(arg, torch.Size):
-        return fn(arg)
-    if isinstance(arg, tuple):
-        return tuple(map_aggregate(elem, fn) for elem in arg)
-    elif isinstance(arg, list):
-        return immutable_list(map_aggregate(elem, fn) for elem in arg)
-    elif isinstance(arg, dict):
-        return immutable_dict((k, map_aggregate(v, fn)) for k, v in arg.items())
-    elif isinstance(arg, slice):
-        return slice(map_aggregate(arg.start, fn), map_aggregate(arg.stop, fn), map_aggregate(arg.step, fn))
-    else:
-        return fn(arg)
 @compatibility(is_backward_compatible=True)
 class MetaInfoProp(torch.fx.Interpreter):
     """
-    Execute an FX graph Node-by-Node and
-    record the shape and type of the result
+    Execute an FX graph Node-by-Node with meta tensor and
+    record the shape, FLOPs, MACs and type of the result
     into the corresponding node.

     Usage:
@@ -104,9 +70,32 @@ class MetaInfoProp(torch.fx.Interpreter):
     """
+    @compatibility(is_backward_compatible=True)
+    def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
+        """
+        Add an additional check on the initial args to ensure that every input
+        torch.Tensor appears with ``device='meta'``.
+        """
+        for elem in args:
+            if isinstance(elem, torch.Tensor):
+                assert elem.is_meta, "Input torch.Tensors are assumed to appear with device='meta'"
+        # initial_env and enable_io_processing are keyword-only parameters of
+        # Interpreter.run, so forward them by keyword.
+        return super().run(*args, initial_env=initial_env, enable_io_processing=enable_io_processing)
+    @compatibility(is_backward_compatible=True)
     def run_node(self, n: Node) -> Any:
-        # TODO: We might run_node(n) with meta data, and count FLOPS for each node
-        result = super().run_node(n)
+        """
+        Run a specific node ``n`` and return the result.
+        Calls into placeholder, get_attr, call_function,
+        call_method, call_module, or output depending
+        on ``node.op``
+
+        Args:
+            n (Node): The Node to execute
+
+        Returns:
+            Any: The result of executing ``n``
+        """
+        result, profile = super().run_node(n)
+        profile: MetaProfile

         def extract_tensor_meta(obj):
             if isinstance(obj, torch.Tensor):
@@ -114,29 +103,139 @@ class MetaInfoProp(torch.fx.Interpreter):
             else:
                 return TensorMetadata(None, None, False, None, 0, False)

-        meta = _map_aggregate(result, extract_tensor_meta)
+        meta = map_aggregate(result, extract_tensor_meta)
         n.meta['tensor_meta'] = meta

-        total_activation_size = 0
-        total_param_size = 0
-        if n.op == 'call_module':
-            target_module = n.graph.owning_module.get_submodule(n.target)
-            if not getattr(target_module, 'inplace', False):
-                total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
-            for param in target_module.parameters():
-                total_param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
-        elif n.op == 'call_function':
-            if 'inplace' not in n.kwargs:
-                total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
-        else:
-            total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
-        setattr(n, 'node_size', total_activation_size + total_param_size)
-        setattr(n, 'param_size', total_param_size)
-        setattr(n, 'activation_size', total_activation_size)
+        # TODO: the attribute node_size should be removed in the future
+        setattr(n, 'node_size', profile.param + profile.activation)
+        setattr(n, '__param__', profile.param)
+        setattr(n, '__activation__', profile.activation)
+        setattr(n, '__flops__', profile.flops)
+        setattr(n, '__macs__', profile.macs)
         n.meta['type'] = type(result)
         return result
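The MetaProfile container consumed above lives in colossalai.fx.profiler and is not part of this diff; judging from its constructor calls and the .param/.activation/.flops/.macs accesses, it is presumably equivalent to this sketch (field names and order inferred, not authoritative):

from typing import NamedTuple

class MetaProfile(NamedTuple):
    # field order inferred from calls such as MetaProfile(0, calculate_activation_size(result), 0, 0)
    param: int          # bytes occupied by parameters
    activation: int     # bytes occupied by output activations
    flops: int          # floating point operations
    macs: int           # multiply-accumulate operations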
# Main Node running APIs
@compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
arguments passed to ``run`` and this method returns
next() on that iterator.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
"""
result = super().placeholder(target, args, kwargs)
# A placeholder node only has activation
return result, MetaProfile(0, calculate_activation_size(result), 0, 0)
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
result (Any): The attribute value that was retrieved
profile (MetaProfile): The meta profile of this node
"""
# A get_attr node never has parameters, activations, FLOPs, or MACs
return super().get_attr(target, args, kwargs), MetaProfile(0, 0, 0, 0)
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
result (Any): The value returned by the function invocation
profile (MetaProfile): The meta profile of this node
"""
assert not isinstance(target, str)
return profile_function(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
result (Any): The value returned by the method invocation
profile (MetaProfile): The meta profile of this node
"""
return profile_method(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
result (Any): The value returned by the module invocation
profile (MetaProfile): The meta profile of this node
"""
# Retrieve executed args and kwargs values from the environment
# Execute the module and return the result
assert isinstance(target, str)
submod = self.fetch_attr(target)
return profile_module(submod)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
Any: The return value referenced by the output node
"""
return args[0], MetaProfile(0, 0, 0, 0)
def propagate(self, *args):
    """
    Run `module` via interpretation and return the result and
...
from .registry import *
from .profiler_function import *
from .profiler_module import *
from .utils import *
from .activation_function import *
from .arithmetic import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
from .python_ops import *
from .torch_ops import *
from typing import Tuple
import torch
from ..registry import meta_profiler_function
# TODO: different activations have different FLOPs counts; this table is currently unused.
_multiplier = {
torch.nn.functional.relu: 1,
torch.nn.functional.prelu: 4,
torch.nn.functional.sigmoid: 4,
torch.nn.functional.tanh: 5,
torch.nn.functional.leaky_relu: 3,
torch.nn.functional.elu: 4,
torch.nn.functional.relu6: 2,
torch.nn.functional.gelu: 9,
}
@meta_profiler_function.register(torch.nn.functional.leaky_relu)
@meta_profiler_function.register(torch.nn.functional.elu)
@meta_profiler_function.register(torch.nn.functional.gelu)
@meta_profiler_function.register(torch.nn.functional.relu6)
@meta_profiler_function.register(torch.nn.functional.prelu)
@meta_profiler_function.register(torch.nn.functional.relu)
@meta_profiler_function.register(torch.nn.functional.sigmoid)
@meta_profiler_function.register(torch.nn.functional.tanh)
def torch_nn_func_non_linear_act(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]:
flops = input.numel()
macs = 0
return flops, macs
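The meta_profiler_function and meta_profiler_module registries are imported from ..registry, which is not shown in this diff. A plausible minimal sketch of the decorator pattern they implement (details assumed, not the project's actual code):

class ProfilerRegistry:
    # maps a function, nn.Module class, or string alias to its (flops, macs) counter

    def __init__(self, name: str):
        self.name = name
        self._store = {}

    def register(self, target):
        # usage: @meta_profiler_function.register(torch.nn.functional.relu)
        def wrapper(fn):
            self._store[target] = fn
            return fn
        return wrapper

    def get(self, target):
        return self._store[target]

meta_profiler_function = ProfilerRegistry('function')
meta_profiler_module = ProfilerRegistry('module')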
from typing import Any, Optional, Tuple, Union
import torch
from ..registry import meta_profiler_function
def _prod(dims):
p = 1
for v in dims:
p *= v
return p
def _elementwise_flops_compute(input, other):
# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763
if not torch.is_tensor(input):
if torch.is_tensor(other):
return _prod(other.shape), 0
else:
return 1, 0
elif not torch.is_tensor(other):
return _prod(input.shape), 0
else:
dim_input = len(input.shape)
dim_other = len(other.shape)
max_dim = max(dim_input, dim_other)
final_shape = []
for i in range(max_dim):
in_i = input.shape[i] if i < dim_input else 1
ot_i = other.shape[i] if i < dim_other else 1
if in_i > ot_i:
final_shape.append(in_i)
else:
final_shape.append(ot_i)
flops = _prod(final_shape)
return flops, 0
@meta_profiler_function.register(torch.add)
@meta_profiler_function.register('add') # for built-in op +
@meta_profiler_function.register('iadd') # for built-in op +=
@meta_profiler_function.register('sub') # for built-in op -
@meta_profiler_function.register('isub') # for built-in op -=
@meta_profiler_function.register('mul') # for built-in op *
@meta_profiler_function.register('imul') # for built-in op *=
def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
return _elementwise_flops_compute(input, other)
@meta_profiler_function.register(torch.abs)
def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
flops = input.numel()
macs = 0
return flops, macs
@meta_profiler_function.register(torch.matmul)
@meta_profiler_function.register('matmul') # for built-in op @
@meta_profiler_function.register(torch.Tensor.matmul)
def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = _prod(input.shape) * other.shape[-1]
flops = 2 * macs
return flops, macs
@meta_profiler_function.register(torch.bmm)
def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = _prod(input.shape) * other.shape[-1]
flops = 2 * macs
return flops, macs
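A quick sanity check of the MAC formula shared by torch_matmul and torch_bmm: for input of shape (B, M, K) and other of shape (B, K, N), each of the B * M * N output elements accumulates K products, which is exactly _prod(input.shape) * other.shape[-1]:

import torch

B, M, K, N = 4, 32, 64, 16
input = torch.empty(B, M, K, device='meta')
other = torch.empty(B, K, N, device='meta')
macs = input.numel() * other.shape[-1]    # B * M * K * N
assert macs == B * M * N * K == 131072
flops = 2 * macs                          # one multiply + one add per MAC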
@meta_profiler_function.register(torch.var_mean)
def torch_var_mean(input: torch.Tensor,
dim: Union[int, Tuple[int, ...]],
unbiased: Optional[bool] = True,
keepdim: Optional[bool] = False,
*,
out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
assert out is None, 'saving to out is not supported yet'
flops = input.numel() * 3
macs = 0
return flops, macs
import torch
from typing import Optional, Tuple
from ..registry import meta_profiler_function
@meta_profiler_function.register(torch.nn.functional.embedding)
def torch_nn_functional_embedding(
input: torch.Tensor,
weight: torch.Tensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
) -> Tuple[int, int]:
# F.embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6)
flops = 0
macs = 0
return flops, macs
from typing import Tuple
import torch
from ..registry import meta_profiler_function
@meta_profiler_function.register(torch.nn.functional.linear)
def torch_nn_linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> Tuple[int, int]:
out_features = weight.shape[0]
macs = torch.numel(input) * out_features
flops = 2 * macs
if bias is not None:
flops += bias.numel()
return flops, macs
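The linear counter follows the same convention: with input of shape (batch, in_features) and weight of shape (out_features, in_features), every output element costs in_features MACs, so macs = input.numel() * out_features. A worked check:

import torch

x = torch.empty(8, 512, device='meta')
w = torch.empty(256, 512, device='meta')    # (out_features, in_features)
b = torch.empty(256, device='meta')
macs = x.numel() * w.shape[0]               # 8 * 512 * 256
flops = 2 * macs + b.numel()                # bias counted once per output feature, as above
assert (macs, flops) == (1048576, 2097408)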
from typing import List, Optional, Tuple
import torch
from ..registry import meta_profiler_function
@meta_profiler_function.register(torch.nn.functional.instance_norm)
def torch_nn_func_instancenorm(
input: torch.Tensor,
running_mean: Optional[torch.Tensor] = None,
running_var: Optional[torch.Tensor] = None,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_input_stats: bool = True,
momentum: float = 0.1,
eps: float = 1e-5,
):
has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
@meta_profiler_function.register(torch.nn.functional.group_norm)
def torch_nn_func_groupnorm(input: torch.Tensor,
num_groups: int,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
eps: float = 1e-5) -> Tuple[int, int]:
has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
@meta_profiler_function.register(torch.nn.functional.layer_norm)
def torch_nn_func_layernorm(
input: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
eps: float = 1e-5,
) -> Tuple[int, int]:
has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
@meta_profiler_function.register(torch.nn.functional.batch_norm)
def torch_nn_func_batchnorm(
input: torch.Tensor,
running_mean: Optional[torch.Tensor],
running_var: Optional[torch.Tensor],
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
training: bool = False,
momentum: float = 0.1,
eps: float = 1e-5,
) -> Tuple[int, int]:
has_affine = weight is not None
if training:
flops = input.numel() * (2 if has_affine else 1)
else:
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
from typing import Tuple, Union
import torch
from ..registry import meta_profiler_function
@meta_profiler_function.register(torch.nn.functional.avg_pool1d)
@meta_profiler_function.register(torch.nn.functional.avg_pool2d)
@meta_profiler_function.register(torch.nn.functional.avg_pool3d)
@meta_profiler_function.register(torch.nn.functional.max_pool1d)
@meta_profiler_function.register(torch.nn.functional.max_pool2d)
@meta_profiler_function.register(torch.nn.functional.max_pool3d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool1d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool2d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool3d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool1d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool2d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool3d)
def torch_nn_func_pooling(input: torch.Tensor, *args, **kwargs) -> Tuple[int, int]:
# every pooling op can be treated as touching each input element exactly once (https://stackoverflow.com/a/67301217)
flops = input.numel()
macs = 0
return flops, macs
import operator
from typing import Any, Tuple
import torch
from ..registry import meta_profiler_function
from colossalai.fx.proxy import ColoProxy
@meta_profiler_function.register(operator.getitem)
def operator_getitem(a: Any, b: Any) -> Tuple[int, int]:
flops = 0
macs = 0
return flops, macs
from typing import Any, Optional, Tuple
import torch
from ..registry import meta_profiler_function
def _prod(dims):
p = 1
for v in dims:
p *= v
return p
@meta_profiler_function.register(torch.arange)
@meta_profiler_function.register(torch.finfo)
@meta_profiler_function.register(torch.permute)
@meta_profiler_function.register(torch.Tensor.permute)
@meta_profiler_function.register(torch.Tensor.repeat)
@meta_profiler_function.register(torch.index_select)
@meta_profiler_function.register(torch.Tensor.index_select)
@meta_profiler_function.register(torch.squeeze)
@meta_profiler_function.register(torch.Tensor.squeeze)
@meta_profiler_function.register(torch.unsqueeze)
@meta_profiler_function.register(torch.Tensor.unsqueeze)
@meta_profiler_function.register(torch.cat)
@meta_profiler_function.register(torch.concat)
@meta_profiler_function.register(torch.repeat_interleave)
@meta_profiler_function.register(torch.Tensor.repeat_interleave)
@meta_profiler_function.register(torch.flatten)
@meta_profiler_function.register(torch.Tensor.flatten)
@meta_profiler_function.register(torch.roll)
@meta_profiler_function.register(torch.full)
@meta_profiler_function.register(torch.Tensor.cpu)
@meta_profiler_function.register(torch.Tensor.cuda)
def torch_zero_flops_op(*args, **kwargs) -> Tuple[int, int]:
flops = 0
macs = 0
return flops, macs
@meta_profiler_function.register(torch.where)
def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]:
# torch.where returns a tensor broadcast from condition, x, and y;
# approximate its cost as one op per element of the condition tensor
flops = condition.numel()
macs = 0
return flops, macs
@meta_profiler_function.register(torch.max)
def torch_max(input: torch.Tensor,
dim: int = None,
keepdim: bool = False,
*,
out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = 0
assert out is None, 'assigning value to out is not supported yet'
if dim is not None:
shape = list(input.shape)
shape.pop(int(dim))
flops = _prod(shape)
return flops, macs
else:
flops = input.numel()
return flops, macs
from .activation_function import *
from .convolution import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
from .rnn import *
from typing import Tuple
import torch
from ..registry import meta_profiler_module
# TODO: different activations have different FLOPs counts; this table is currently unused.
_multiplier = {
torch.nn.ReLU: 1,
torch.nn.PReLU: 4,
torch.nn.Sigmoid: 4,
torch.nn.Tanh: 5,
torch.nn.LeakyReLU: 3,
torch.nn.ELU: 4,
torch.nn.ReLU6: 2,
torch.nn.GELU: 9,
}
@meta_profiler_module.register(torch.nn.ELU)
@meta_profiler_module.register(torch.nn.LeakyReLU)
@meta_profiler_module.register(torch.nn.ReLU)
@meta_profiler_module.register(torch.nn.GELU)
@meta_profiler_module.register(torch.nn.Sigmoid)
@meta_profiler_module.register(torch.nn.Tanh)
@meta_profiler_module.register(torch.nn.ReLU6)
@meta_profiler_module.register(torch.nn.PReLU)
def torch_nn_non_linear_act(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
flops = input.numel()
macs = 0
return flops, macs
import math
from typing import Tuple
import torch
from ..registry import meta_profiler_module
def _prod(dims):
p = 1
for v in dims:
p *= v
return p
@meta_profiler_module.register(torch.nn.Conv1d)
def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(result_shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += num_elem
return flops, macs
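To make the convolution arithmetic concrete, a worked example for the Conv1d counter above (an illustrative configuration: stride 1, no padding or dilation, groups=1):

import math
import torch

conv = torch.nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3)
# input of shape (1, 16, 100); output length: floor((100 + 0 - 1*(3-1) - 1)/1 + 1) = 98
l_out = math.floor((100 + 2 * 0 - 1 * (3 - 1) - 1) / 1 + 1)
num_elem = 1 * 32 * l_out                  # elements in the output tensor
macs = (3 * 16 // 1) * num_elem            # kernel_size * c_in // groups per output element
flops = 2 * macs + num_elem                # plus one add per output element for the bias
assert (l_out, macs, flops) == (98, 150528, 304192)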
@meta_profiler_module.register(torch.nn.Conv2d)
def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(result_shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += num_elem
return flops, macs
@meta_profiler_module.register(torch.nn.Conv3d)
def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
(self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(result_shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += num_elem
return flops, macs
@meta_profiler_module.register(torch.nn.ConvTranspose1d)
def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(
input.shape
) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += _prod(result_shape)
return flops, macs
@meta_profiler_module.register(torch.nn.ConvTranspose2d)
def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
(self.kernel_size[1] - 1) + self.output_padding[1] + 1)
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(input.shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += _prod(result_shape)
return flops, macs
@meta_profiler_module.register(torch.nn.ConvTranspose3d)
def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
(self.kernel_size[1] - 1) + self.output_padding[1] + 1)
w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] *
(self.kernel_size[2] - 1) + self.output_padding[2] + 1)
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(input.shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += _prod(result_shape)
return flops, macs
from typing import Tuple
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.Embedding)
def torch_nn_embedding(self: torch.nn.Embedding, input: torch.Tensor) -> Tuple[int, int]:
# nn.Embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6)
flops = 0
macs = 0
return flops, macs
\ No newline at end of file
from typing import Tuple
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.Linear)
def torch_nn_linear(self: torch.nn.Linear, input: torch.Tensor) -> Tuple[int, int]:
out_features = self.weight.shape[0]
macs = torch.numel(input) * out_features
flops = 2 * macs
if self.bias is not None:
flops += self.bias.numel()
return flops, macs
from typing import Tuple, Union
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.InstanceNorm1d)
@meta_profiler_module.register(torch.nn.InstanceNorm2d)
@meta_profiler_module.register(torch.nn.InstanceNorm3d)
@meta_profiler_module.register(torch.nn.LayerNorm)
@meta_profiler_module.register(torch.nn.GroupNorm)
@meta_profiler_module.register(torch.nn.BatchNorm1d)
@meta_profiler_module.register(torch.nn.BatchNorm2d)
@meta_profiler_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]:
# adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615
has_affine = self.weight is not None
if self.training:
flops = input.numel() * (2 if has_affine else 1)
else:
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
try:
import apex
meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize)
except (ImportError, AttributeError):
pass
from typing import Tuple
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.AvgPool1d)
@meta_profiler_module.register(torch.nn.AvgPool2d)
@meta_profiler_module.register(torch.nn.AvgPool3d)
@meta_profiler_module.register(torch.nn.MaxPool1d)
@meta_profiler_module.register(torch.nn.MaxPool2d)
@meta_profiler_module.register(torch.nn.MaxPool3d)
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool1d)
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool1d)
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool2d)
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool2d)
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool3d)
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool3d)
def torch_nn_pooling(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
# every pooling op can be treated as touching each input element exactly once (https://stackoverflow.com/a/67301217)
flops = input.numel()
macs = 0
return flops, macs
import torch
from ..registry import meta_profiler_module
from typing import Optional, Tuple
# TODO: calculate rnn FLOPs
@meta_profiler_module.register(torch.nn.GRU)
@meta_profiler_module.register(torch.nn.RNN)
def torch_nn_rnn(self: torch.nn.Module, input: torch.Tensor, hx: torch.Tensor) -> Tuple[int, int]:
raise NotImplementedError
flops = 0
macs = 0
return flops, macs
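The recurrent counters are left as a TODO above (the lines after raise are unreachable placeholders). For reference, one rough way such a counter could be filled in for a single-layer GRU, counting only the six gate matmuls per timestep; this is a hedged sketch, not the project's implementation, and it ignores elementwise gate arithmetic:

from typing import Tuple
import torch

def gru_flops_sketch(self: torch.nn.GRU, input: torch.Tensor, hx: torch.Tensor = None) -> Tuple[int, int]:
    # assumes a single layer and batch_first=False input of shape (seq_len, batch, input_size);
    # each timestep performs 3 input-to-hidden and 3 hidden-to-hidden gate matmuls
    seq_len, batch, input_size = input.shape
    h = self.hidden_size
    macs = 3 * h * (input_size + h) * batch * seq_len
    flops = 2 * macs
    return flops, macs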