Commit 08f2920e authored by zhuwenwen

init colossalai, support dtk2304

parent da3f0934
Pipeline #237 failed with stages in 0 seconds
from typing import Tuple
import torch
from ..registry import meta_profiler_function
@meta_profiler_function.register(torch.nn.functional.linear)
def torch_nn_linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> Tuple[int, int]:
out_features = weight.shape[0]
macs = torch.numel(input) * out_features
flops = 2 * macs
if bias is not None:
flops += bias.numel()
return flops, macs
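
# Illustrative sketch (added for this edit, not part of the original file): a quick
# sanity check of the count above with example shapes. For an input of shape (8, 128)
# and a weight of shape (256, 128), each output element needs 128 multiply-accumulates,
# so macs = 8 * 128 * 256 and flops = 2 * macs plus one addition per bias element.
if __name__ == '__main__':
    _flops, _macs = torch_nn_linear(torch.randn(8, 128), torch.randn(256, 128), torch.randn(256))
    assert _macs == 8 * 128 * 256
    assert _flops == 2 * _macs + 256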
from typing import List, Optional, Tuple
import torch
from ..registry import meta_profiler_function
@meta_profiler_function.register(torch.nn.functional.instance_norm)
def torch_nn_func_instancenorm(
input: torch.Tensor,
running_mean: Optional[torch.Tensor] = None,
running_var: Optional[torch.Tensor] = None,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_input_stats: bool = True,
momentum: float = 0.1,
eps: float = 1e-5,
):
has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
@meta_profiler_function.register(torch.nn.functional.group_norm)
def torch_nn_func_groupnorm(input: torch.Tensor,
num_groups: int,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
eps: float = 1e-5) -> Tuple[int, int]:
has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
@meta_profiler_function.register(torch.nn.functional.layer_norm)
def torch_nn_func_layernorm(
input: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
eps: float = 1e-5,
) -> Tuple[int, int]:
has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
@meta_profiler_function.register(torch.nn.functional.batch_norm)
def torch_nn_func_batchnorm(
input: torch.Tensor,
running_mean: Optional[torch.Tensor],
running_var: Optional[torch.Tensor],
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
training: bool = False,
momentum: float = 0.1,
eps: float = 1e-5,
) -> Tuple[int, int]:
has_affine = weight is not None
if training:
flops = input.numel() * (2 if has_affine else 1)
else:
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
from typing import Tuple, Union
import torch
from ..registry import meta_profiler_function
@meta_profiler_function.register(torch.nn.functional.avg_pool1d)
@meta_profiler_function.register(torch.nn.functional.avg_pool2d)
@meta_profiler_function.register(torch.nn.functional.avg_pool3d)
@meta_profiler_function.register(torch.nn.functional.max_pool1d)
@meta_profiler_function.register(torch.nn.functional.max_pool2d)
@meta_profiler_function.register(torch.nn.functional.max_pool3d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool1d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool2d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool3d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool1d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool2d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool3d)
def torch_nn_func_pooling(input: torch.Tensor, *args, **kwargs) -> Tuple[int, int]:
# all pooling could be considered as going over each input element only once (https://stackoverflow.com/a/67301217)
flops = input.numel()
macs = 0
return flops, macs
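
# Illustrative sketch (added for this edit, not part of the original file): since pooling
# is counted as one FLOP per input element regardless of kernel size, a 2x2 max-pool over
# a (1, 3, 32, 32) input is estimated at 3 * 32 * 32 = 3072 FLOPs and 0 MACs.
if __name__ == '__main__':
    _flops, _macs = torch_nn_func_pooling(torch.randn(1, 3, 32, 32), kernel_size=2)
    assert (_flops, _macs) == (3072, 0)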
import operator
from typing import Any, Tuple
import torch
from ..registry import meta_profiler_function
@meta_profiler_function.register(operator.getitem)
def operator_getitem(a: Any, b: Any) -> Tuple[int, int]:
flops = 0
macs = 0
return flops, macs
@meta_profiler_function.register(getattr)
def python_getattr(a: Any, b: Any) -> Tuple[int, int]:
flops = 0
macs = 0
return flops, macs
from functools import reduce
import operator
from typing import Any, Optional, Tuple
import torch
from ..registry import meta_profiler_function
@meta_profiler_function.register(torch.arange)
@meta_profiler_function.register(torch.finfo)
@meta_profiler_function.register(torch.permute)
@meta_profiler_function.register(torch.Tensor.permute)
@meta_profiler_function.register(torch.Tensor.repeat)
@meta_profiler_function.register(torch.index_select)
@meta_profiler_function.register(torch.Tensor.index_select)
@meta_profiler_function.register(torch.squeeze)
@meta_profiler_function.register(torch.Tensor.squeeze)
@meta_profiler_function.register(torch.unsqueeze)
@meta_profiler_function.register(torch.Tensor.unsqueeze)
@meta_profiler_function.register(torch.cat)
@meta_profiler_function.register(torch.concat)
@meta_profiler_function.register(torch.repeat_interleave)
@meta_profiler_function.register(torch.Tensor.repeat_interleave)
@meta_profiler_function.register(torch.flatten)
@meta_profiler_function.register(torch.Tensor.flatten)
@meta_profiler_function.register(torch.roll)
@meta_profiler_function.register(torch.full)
@meta_profiler_function.register(torch.Tensor.cpu)
@meta_profiler_function.register(torch.Tensor.cuda)
@meta_profiler_function.register(torch._assert)
def torch_zero_flops_op(*args, **kwargs) -> Tuple[int, int]:
flops = 0
macs = 0
return flops, macs
@meta_profiler_function.register(torch.where)
def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]:
    # torch.where returns a tensor broadcast from condition, x, and y;
    # approximate its cost as one operation per element of condition
flops = condition.numel()
macs = 0
return flops, macs
@meta_profiler_function.register(torch.max)
def torch_max(input: torch.Tensor,
dim: int = None,
keepdim: bool = False,
*,
out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = 0
assert out is None, 'assigning value to out is not supported yet'
if dim is not None:
shape = list(input.shape)
shape.pop(int(dim))
        flops = reduce(operator.mul, shape)
return flops, macs
else:
flops = input.numel()
return flops, macs
from .activation_function import *
from .attention import *
from .convolution import *
from .dropout import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
from .rnn import *
from .torch_op import *
from typing import Tuple
import torch
from ..registry import meta_profiler_module
# TODO: different activations have different FLOP counts; this table is currently unused.
_multiplier = {
torch.nn.ReLU: 1,
torch.nn.PReLU: 4,
torch.nn.Sigmoid: 4,
torch.nn.Tanh: 5,
torch.nn.LeakyReLU: 3,
torch.nn.ELU: 4,
torch.nn.ReLU6: 2,
torch.nn.GELU: 9,
torch.nn.Hardswish: 5,
torch.nn.Hardsigmoid: 4,
}
@meta_profiler_module.register(torch.nn.ELU)
@meta_profiler_module.register(torch.nn.LeakyReLU)
@meta_profiler_module.register(torch.nn.ReLU)
@meta_profiler_module.register(torch.nn.GELU)
@meta_profiler_module.register(torch.nn.Sigmoid)
@meta_profiler_module.register(torch.nn.Tanh)
@meta_profiler_module.register(torch.nn.ReLU6)
@meta_profiler_module.register(torch.nn.PReLU)
@meta_profiler_module.register(torch.nn.Hardswish)
@meta_profiler_module.register(torch.nn.Hardsigmoid)
def torch_nn_non_linear_act(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
flops = input.numel()
macs = 0
return flops, macs
from typing import Optional, Tuple
import torch
from ..registry import meta_profiler_module
# TODO: the memory cost of this module is hard to compute
@meta_profiler_module.register(torch.nn.MultiheadAttention)
def torch_nn_msa(self: torch.nn.MultiheadAttention,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[torch.Tensor] = None,
average_attn_weights: bool = True) -> Tuple[int, int]:
if getattr(self, 'batch_first', False):
batch_size = query.shape[0]
len_idx = 1
else:
batch_size = query.shape[1]
len_idx = 0
dim_idx = 2
qdim = query.shape[dim_idx]
kdim = key.shape[dim_idx]
vdim = value.shape[dim_idx]
qlen = query.shape[len_idx]
klen = key.shape[len_idx]
vlen = value.shape[len_idx]
num_heads = self.num_heads
assert qdim == self.embed_dim
if self.kdim is None:
assert kdim == qdim
if self.vdim is None:
assert vdim == qdim
flops = 0
macs = 0
# Q scaling
flops += qlen * qdim
# Initial projections
flops += 2 * ((qlen * qdim * qdim) # QW
+ (klen * kdim * kdim) # KW
+ (vlen * vdim * vdim) # VW
)
macs += ((qlen * qdim * qdim) # QW
+ (klen * kdim * kdim) # KW
+ (vlen * vdim * vdim) # VW
)
if self.in_proj_bias is not None:
flops += (qlen + klen + vlen) * qdim
# attention heads: scale, matmul, softmax, matmul
qk_head_dim = qdim // num_heads
v_head_dim = vdim // num_heads
head_flops = (
2 * (qlen * klen * qk_head_dim) # QK^T
+ (qlen * klen) # softmax
+ 2 * (qlen * klen * v_head_dim) # AV
)
head_macs = ((qlen * klen * qk_head_dim) # QK^T
+ 2 * (qlen * klen * v_head_dim) # AV
)
flops += num_heads * head_flops
    macs += num_heads * head_macs
# final projection, bias is always enabled
flops += qlen * vdim * (vdim + 1)
flops *= batch_size
macs *= batch_size
return flops, macs
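
# Illustrative sketch (added for this edit, not part of the original file): a small
# self-attention configuration run through the estimator above. The shapes are arbitrary
# example values; the check only asserts the structural relation that the FLOP estimate
# dominates the MAC estimate.
if __name__ == '__main__':
    _mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=4)
    _q = torch.randn(16, 2, 64)    # (seq_len, batch, embed_dim), batch_first=False
    _flops, _macs = torch_nn_msa(_mha, _q, _q, _q)
    assert _flops > _macs > 0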
import operator
from functools import reduce
import math
from typing import Tuple
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.Conv1d)
def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(operator.mul, result_shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += num_elem
return flops, macs
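
# Illustrative sketch (added for this edit, not part of the original file): checking the
# output-length formula for example values. With l_in=32, kernel_size=3, stride=2,
# padding=1, dilation=1: l_out = floor((32 + 2 - 2 - 1) / 2 + 1) = 16, so a (1, 4, 32)
# input mapped to 8 output channels gives macs = (3 * 4) * (1 * 8 * 16) = 1536.
if __name__ == '__main__':
    _conv = torch.nn.Conv1d(4, 8, kernel_size=3, stride=2, padding=1)
    _flops, _macs = torch_nn_conv1d(_conv, torch.randn(1, 4, 32))
    assert _macs == 3 * 4 * 8 * 16
    assert _flops == 2 * _macs + 1 * 8 * 16    # the bias adds one FLOP per output element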
@meta_profiler_module.register(torch.nn.Conv2d)
def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(operator.mul, result_shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += num_elem
return flops, macs
@meta_profiler_module.register(torch.nn.Conv3d)
def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
(self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(operator.mul, result_shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += num_elem
return flops, macs
@meta_profiler_module.register(torch.nn.ConvTranspose1d)
def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(
operator.mul, input.shape
) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += reduce(operator.mul, result_shape)
return flops, macs
@meta_profiler_module.register(torch.nn.ConvTranspose2d)
def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
(self.kernel_size[1] - 1) + self.output_padding[1] + 1)
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(operator.mul, input.shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += reduce(operator.mul, result_shape)
return flops, macs
@meta_profiler_module.register(torch.nn.ConvTranspose3d)
def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
(self.kernel_size[1] - 1) + self.output_padding[1] + 1)
w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] *
(self.kernel_size[2] - 1) + self.output_padding[2] + 1)
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(operator.mul, input.shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += reduce(operator.mul, result_shape)
return flops, macs
from typing import Tuple
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.Dropout)
def torch_nn_dropout(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
    # Dropout only masks elements (and is the identity in eval mode), so it is counted as 0 FLOPs.
flops = 0
macs = 0
return flops, macs
from typing import Tuple
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.Embedding)
def torch_nn_embedding(self: torch.nn.Embedding, input: torch.Tensor) -> Tuple[int, int]:
# nn.Embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6)
flops = 0
macs = 0
return flops, macs
from typing import Tuple
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.Linear)
@meta_profiler_module.register(torch.nn.modules.linear.NonDynamicallyQuantizableLinear)
def torch_nn_linear(self: torch.nn.Linear, input: torch.Tensor) -> Tuple[int, int]:
out_features = self.weight.shape[0]
macs = input.numel() * out_features
flops = 2 * macs
if self.bias is not None:
flops += self.bias.numel()
return flops, macs
from typing import Tuple, Union
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.InstanceNorm1d)
@meta_profiler_module.register(torch.nn.InstanceNorm2d)
@meta_profiler_module.register(torch.nn.InstanceNorm3d)
@meta_profiler_module.register(torch.nn.LayerNorm)
@meta_profiler_module.register(torch.nn.GroupNorm)
@meta_profiler_module.register(torch.nn.BatchNorm1d)
@meta_profiler_module.register(torch.nn.BatchNorm2d)
@meta_profiler_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]:
# adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615
has_affine = self.weight is not None
if self.training:
flops = input.numel() * (2 if has_affine else 1)
else:
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
try:
import apex
meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize)
except (ImportError, AttributeError):
pass
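
# Illustrative sketch (added for this edit, not part of the original file): the estimate
# above depends on the module's training flag; with affine parameters it counts 2 FLOPs
# per element in training mode and 5 per element in eval mode. The shapes are example values.
if __name__ == '__main__':
    _bn = torch.nn.BatchNorm2d(8)
    _x = torch.randn(2, 8, 4, 4)
    assert torch_nn_normalize(_bn.train(), _x)[0] == 2 * _x.numel()
    assert torch_nn_normalize(_bn.eval(), _x)[0] == 5 * _x.numel()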
from typing import Tuple
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.AvgPool1d)
@meta_profiler_module.register(torch.nn.AvgPool2d)
@meta_profiler_module.register(torch.nn.AvgPool3d)
@meta_profiler_module.register(torch.nn.MaxPool1d)
@meta_profiler_module.register(torch.nn.MaxPool2d)
@meta_profiler_module.register(torch.nn.MaxPool3d)
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool1d)
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool1d)
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool2d)
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool2d)
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool3d)
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool3d)
def torch_nn_pooling(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
# all pooling could be considered as going over each input element only once (https://stackoverflow.com/a/67301217)
flops = input.numel()
macs = 0
return flops, macs
from functools import reduce
import operator
import torch
from ..registry import meta_profiler_module
from typing import Optional, Tuple, Union
def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor,
w_hh: torch.Tensor) -> Tuple[int, int]:
# copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py
# matrix matrix mult ih state and internal state
macs += reduce(operator.mul, w_ih.shape)
flops += 2 * reduce(operator.mul, w_ih.shape)
# matrix matrix mult hh state and internal state
macs += reduce(operator.mul, w_hh.shape)
flops += 2 * reduce(operator.mul, w_hh.shape)
if isinstance(module, (torch.nn.RNN, torch.nn.RNNCell)):
# add both operations
flops += module.hidden_size
elif isinstance(module, (torch.nn.GRU, torch.nn.GRUCell)):
# hadamard of r
flops += module.hidden_size
# adding operations from both states
flops += module.hidden_size * 3
# last two hadamard product and add
flops += module.hidden_size * 3
elif isinstance(module, (torch.nn.LSTM, torch.nn.LSTMCell)):
# adding operations from both states
flops += module.hidden_size * 4
# two hadamard product and add for C state
flops += module.hidden_size * 3
# final hadamard
flops += module.hidden_size * 3
return flops, macs
@meta_profiler_module.register(torch.nn.LSTM)
@meta_profiler_module.register(torch.nn.GRU)
@meta_profiler_module.register(torch.nn.RNN)
def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
flops = 0
macs = 0
for i in range(self.num_layers):
w_ih = self.__getattr__('weight_ih_l' + str(i))
w_hh = self.__getattr__('weight_hh_l' + str(i))
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
if self.bias:
b_ih = self.__getattr__('bias_ih_l' + str(i))
b_hh = self.__getattr__('bias_hh_l' + str(i))
            flops += reduce(operator.mul, b_ih.shape) + reduce(operator.mul, b_hh.shape)
flops *= reduce(operator.mul, input.shape[:2])
macs *= reduce(operator.mul, input.shape[:2])
if self.bidirectional:
flops *= 2
macs *= 2
return flops, macs
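
# Illustrative sketch (added for this edit, not part of the original file): the per-step
# cost is scaled by seq_len * batch_size (and doubled for bidirectional RNNs), so doubling
# the sequence length should exactly double the estimate. The LSTM sizes are example values.
if __name__ == '__main__':
    _lstm = torch.nn.LSTM(input_size=10, hidden_size=20)
    _short, _ = torch_nn_rnn(_lstm, torch.randn(5, 3, 10))
    _long, _ = torch_nn_rnn(_lstm, torch.randn(10, 3, 10))
    assert _long == 2 * _short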
@meta_profiler_module.register(torch.nn.LSTMCell)
@meta_profiler_module.register(torch.nn.GRUCell)
@meta_profiler_module.register(torch.nn.RNNCell)
def torch_nn_rnn_cell(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    flops = 0
    macs = 0
    w_ih = self.__getattr__('weight_ih')
    w_hh = self.__getattr__('weight_hh')
    flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
    if self.bias:
        b_ih = self.__getattr__('bias_ih')
        b_hh = self.__getattr__('bias_hh')
        flops += reduce(operator.mul, b_ih.shape) + reduce(operator.mul, b_hh.shape)
    flops *= input.shape[0]
    macs *= input.shape[0]
    return flops, macs
import operator
import torch
from ..registry import meta_profiler_module
from typing import Optional, Tuple, Union
@meta_profiler_module.register(torch.nn.Flatten)
def torch_nn_flatten(self: torch.nn.Flatten, input: torch.Tensor) -> Tuple[int, int]:
flops = 0
macs = 0
return flops, macs
class ProfilerRegistry:
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
def wrapper(func):
self.store[source] = func
return func
return wrapper
def get(self, source):
assert source in self.store
target = self.store[source]
return target
def has(self, source):
return source in self.store
meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile')
meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile')
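
# Illustrative sketch (added for this edit, not part of the original file): the registry
# maps a callable or module class to the profiler that estimates its (flops, macs); this
# is how the `register` decorators in the profiler_function/profiler_module files are
# consumed. The `len` entry below is a made-up example, not a real registration.
if __name__ == '__main__':
    _demo = ProfilerRegistry(name='demo')

    @_demo.register(len)
    def _len_profiler(*args, **kwargs):
        return 0, 0    # querying a length involves no floating-point work

    assert _demo.has(len)
    assert _demo.get(len)('abc') == (0, 0)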
# for PyTorch 1.11 compatibility
from typing import Dict, List, Tuple, Union
import torch
from torch.fx import GraphModule, Node
from ..._compatibility import compatibility
__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
@compatibility(is_backward_compatible=True)
def calculate_fwd_in(n: Node) -> bool:
"""A helper function to calculate `fwd_in`
Args:
n (Node): a node from the graph
Returns:
save_fwd_in (bool): the result of `save_fwd_in`
"""
return n.meta['save_fwd_in']
@compatibility(is_backward_compatible=True)
def calculate_fwd_tmp(n: Node) -> int:
"""A helper function to calculate `fwd_tmp`
Args:
n (Node): a node from the graph
Returns:
fwd_tmp (int): the result of `fwd_tmp`
"""
return n.meta["fwd_mem_tmp"]
@compatibility(is_backward_compatible=True)
def calculate_fwd_out(n: Node) -> int:
"""A helper function to calculate `fwd_out`
Args:
n (Node): a node from the graph
Returns:
fwd_out (int): the result of `fwd_out`
"""
return n.meta['fwd_mem_out']
from typing import Dict, List, Tuple, Union
import torch
from torch.fx import GraphModule, Node
from .._compatibility import compatibility, is_compatible_with_meta
__all__ = ['activation_size', 'parameter_size', 'is_inplace']
@compatibility(is_backward_compatible=True)
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node.
Args:
        out (Union[torch.Tensor, Dict, List, Tuple, int]): The output activation of a `torch.nn.Module` or `torch.nn.functional` call.
Returns:
int: The activation size, unit is byte.
"""
act_size = 0
if isinstance(out, torch.Tensor):
if out.is_quantized:
act_size += out.numel() * torch._empty_affine_quantized([], dtype=out.dtype).element_size()
else:
act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
elif isinstance(out, dict):
value_list = [v for _, v in out.items()]
act_size += activation_size(value_list)
elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
for element in out:
act_size += activation_size(element)
return act_size
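
# Illustrative sketch (added for this edit, not part of the original file): a float32
# tensor stores 4 bytes per element, so a (2, 3) float32 tensor counts as 24 bytes, and
# containers are summed recursively.
if __name__ == '__main__':
    _t = torch.zeros(2, 3, dtype=torch.float32)
    assert activation_size(_t) == 24
    assert activation_size({'a': _t, 'b': [_t, _t]}) == 72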
@compatibility(is_backward_compatible=True)
def parameter_size(mod: torch.nn.Module) -> int:
"""Calculate parameter size of a node.
Args:
mod (torch.nn.Module): The target `torch.nn.Module`.
Returns:
int: The parameter size, unit is byte.
"""
param_size = 0
for param in mod.parameters():
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
return param_size
def is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node
Args:
node (Node): torch.fx.Node
Returns:
bool: indicates whether this op is inplace
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
if is_compatible_with_meta():
from .constants import ALIAS_ATEN
if n.target in ALIAS_ATEN:
inplace = True
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace
# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
# ideas from https://pastebin.com/AkvAyJBw
import operator
from functools import partial, reduce
from numbers import Number
from typing import Any, Callable, List
import torch
aten = torch.ops.aten
def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for matmul.
"""
# Inputs should be a list of length 2.
# Inputs contains the shapes of two matrices.
input_shapes = [v.shape for v in inputs]
assert len(input_shapes) == 2, input_shapes
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
return flops
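
# Illustrative sketch (added for this edit, not part of the original file): for an
# (8, 16) x (16, 32) matmul the handler counts 8 * 16 * 32 = 4096 FLOPs (multiplications
# only, following the fvcore convention). The shapes are example values.
if __name__ == '__main__':
    assert matmul_flop_jit([torch.empty(8, 16), torch.empty(16, 32)], [torch.empty(8, 32)]) == 8 * 16 * 32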
def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for fully connected layers.
"""
# Count flop for nn.Linear
# inputs is a list of length 3.
input_shapes = [v.shape for v in inputs[1:3]]
# input_shapes[0]: [batch size, input feature dimension]
# input_shapes[1]: [input feature dimension, output feature dimension]
assert len(input_shapes[0]) == 2, input_shapes[0]
assert len(input_shapes[1]) == 2, input_shapes[1]
batch_size, input_dim = input_shapes[0]
output_dim = input_shapes[1][1]
flops = batch_size * input_dim * output_dim
return flops
def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for the aten::linear operator.
"""
# Inputs is a list of length 3; unlike aten::addmm, it is the first
# two elements that are relevant.
input_shapes = [v.shape for v in inputs[0:2]]
# input_shapes[0]: [dim0, dim1, ..., input_feature_dim]
# input_shapes[1]: [output_feature_dim, input_feature_dim]
assert input_shapes[0][-1] == input_shapes[1][-1]
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[1][0]
return flops
def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for the bmm operation.
"""
# Inputs should be a list of length 2.
# Inputs contains the shapes of two tensor.
assert len(inputs) == 2, len(inputs)
input_shapes = [v.shape for v in inputs]
n, c, t = input_shapes[0]
d = input_shapes[-1][-1]
flops = n * c * t * d
return flops
def conv_flop_count(
x_shape: List[int],
w_shape: List[int],
out_shape: List[int],
transposed: bool = False,
) -> Number:
"""
Count flops for convolution. Note only multiplication is
counted. Computation for addition and bias is ignored.
    FLOPs for a transposed convolution are calculated as
    flops = batch_size * prod(w_shape) * prod(x_shape[2:]).
Args:
x_shape (list(int)): The input shape before convolution.
w_shape (list(int)): The filter shape.
out_shape (list(int)): The output shape after convolution.
transposed (bool): is the convolution transposed
Returns:
int: the number of flops
"""
batch_size = x_shape[0]
conv_shape = (x_shape if transposed else out_shape)[2:]
flops = batch_size * reduce(operator.mul, w_shape) * reduce(operator.mul, conv_shape)
return flops
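
# Illustrative sketch (added for this edit, not part of the original file): for a 3x3
# convolution mapping a (1, 4, 8, 8) input to a (1, 16, 8, 8) output, the count is
# batch_size * prod(w_shape) * prod(out_spatial) = 1 * (16 * 4 * 3 * 3) * (8 * 8) = 36864.
if __name__ == '__main__':
    assert conv_flop_count([1, 4, 8, 8], [16, 4, 3, 3], [1, 16, 8, 8]) == 36864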
def conv_flop_jit(inputs: List[Any], outputs: List[Any]):
"""
Count flops for convolution.
"""
x, w = inputs[:2]
x_shape, w_shape, out_shape = (x.shape, w.shape, outputs[0].shape)
transposed = inputs[6]
return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
def transpose_shape(shape):
return [shape[1], shape[0]] + list(shape[2:])
def conv_backward_flop_jit(inputs: List[Any], outputs: List[Any]):
grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]]
output_mask = inputs[-1]
fwd_transposed = inputs[7]
flop_count = 0
if output_mask[0]:
grad_input_shape = outputs[0].shape
flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
if output_mask[1]:
grad_weight_shape = outputs[1].shape
flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)
return flop_count
def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
"""
Args:
        affine_arg_index: index of the affine argument in inputs
        input_arg_index: index of the input tensor in inputs
    """
def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for norm layers.
"""
# Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape
has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
'shape') else inputs[affine_arg_index]
assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
return flop
return norm_flop_jit
def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number:
if training is None:
training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training:
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1)
def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Callable:
"""
Count flops by
input_tensor.numel() * input_scale + output_tensor.numel() * output_scale
Args:
input_scale: scale of the input tensor (first argument)
output_scale: scale of the output tensor (first element in outputs)
"""
def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number:
ret = 0
if input_scale != 0:
shape = inputs[0].shape
ret += input_scale * reduce(operator.mul, shape) if shape else 0
if output_scale != 0:
shape = outputs[0].shape
ret += output_scale * reduce(operator.mul, shape) if shape else 0
return ret
return elementwise_flop
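
# Illustrative sketch (added for this edit, not part of the original file): the two scale
# factors let the same factory cover forward ops (counted per input element) and their
# backward ops (counted per output element), as done for the pooling entries in
# flop_mapping below. The shapes are example values.
if __name__ == '__main__':
    _fwd, _bwd = elementwise_flop_counter(1, 0), elementwise_flop_counter(0, 1)
    _x, _y = torch.empty(2, 8), torch.empty(2, 4)
    assert _fwd([_x], [_y]) == 16
    assert _bwd([_x], [_y]) == 8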
def zero_flop_jit(*args):
"""
Count flops for zero flop layers.
"""
return 0
flop_mapping = {
# gemm
aten.mm.default: matmul_flop_jit,
aten.matmul.default: matmul_flop_jit,
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
# convolution
aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit,
# normalization
aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
aten.native_layer_norm.default: norm_flop_counter(2, 0),
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
# pooling
aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
aten.avg_pool3d.default: elementwise_flop_counter(1, 0),
aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
aten.max_pool1d.default: elementwise_flop_counter(1, 0),
aten.max_pool2d.default: elementwise_flop_counter(1, 0),
aten.max_pool3d.default: elementwise_flop_counter(1, 0),
aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0),
aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0),
aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1),
aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0),
aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1),
aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0),
aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
aten.embedding.default: elementwise_flop_counter(1, 0),
}
elementwise_flop_aten = [
# basic op
aten.add.Tensor,
aten.add_.Tensor,
aten.div.Tensor,
aten.div_.Tensor,
aten.div.Scalar,
aten.div_.Scalar,
aten.mul.Tensor,
aten.mul.Scalar,
aten.mul_.Tensor,
aten.neg.default,
aten.pow.Tensor_Scalar,
aten.rsub.Scalar,
aten.sum.default,
aten.sum.dim_IntList,
aten.mean.dim,
# activation op
aten.hardswish.default,
aten.hardswish_.default,
aten.hardswish_backward.default,
aten.hardtanh.default,
aten.hardtanh_.default,
aten.hardtanh_backward.default,
aten.hardsigmoid_backward.default,
aten.hardsigmoid.default,
aten.gelu.default,
aten.gelu_backward.default,
aten.silu.default,
aten.silu_.default,
aten.silu_backward.default,
aten.sigmoid.default,
aten.sigmoid_backward.default,
aten._softmax.default,
aten._softmax_backward_data.default,
aten.relu_.default,
aten.relu.default,
aten.tanh.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
# dropout
aten.native_dropout.default,
aten.native_dropout_backward.default,
]
for op in elementwise_flop_aten:
flop_mapping[op] = elementwise_flop_counter(1, 0)
# TODO: this will be removed in the future
zero_flop_aten = [
aten.as_strided.default,
aten.as_strided_.default,
aten.bernoulli_.float,
aten.cat.default,
aten.clone.default,
aten.copy_.default,
aten.detach.default,
aten.expand.default,
aten.empty_like.default,
aten.new_empty.default,
aten.new_empty_strided.default,
aten.ones_like.default,
aten._reshape_alias.default,
aten.select.int,
aten.select_backward.default,
aten.squeeze.dim,
aten.slice.Tensor,
aten.slice_backward.default,
aten.split.Tensor,
aten.permute.default,
aten.t.default,
aten.transpose.int,
aten._to_copy.default,
aten.unsqueeze.default,
aten.unbind.int,
aten._unsafe_view.default,
aten.view.default,
aten.where.self,
aten.zero_.default,
aten.zeros_like.default,
]
for op in zero_flop_aten:
flop_mapping[op] = zero_flop_jit
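
# Illustrative sketch (added for this edit, not part of the original file): flop_mapping
# is keyed by ATen overloads, so a tracer that records `aten.mm.default(a, b) -> out` can
# estimate its cost by looking up the overload and calling the handler with the inputs
# and outputs. The shapes are example values.
if __name__ == '__main__':
    _a, _b, _out = torch.empty(8, 16), torch.empty(16, 32), torch.empty(8, 32)
    assert flop_mapping[aten.mm.default]([_a, _b], [_out]) == 8 * 16 * 32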