Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
import torch
from typing import Dict
from torch.fx.node import Node, map_arg
import torch
from torch.fx.graph import Graph
from torch.fx.node import Node, map_arg
def get_comm_size(prev_partition, next_partition):
"""
......@@ -23,7 +25,7 @@ def get_comm_size(prev_partition, next_partition):
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
for n in input_nodes:
if n.name in parent_node_names and n not in visited_nodes:
comm_size += n.meta['tensor_meta'].numel
comm_size += n.meta["tensor_meta"].numel
visited_nodes.add(n)
return comm_size
......@@ -36,12 +38,12 @@ def get_leaf(graph: Graph):
"""
input_nodes: Dict[Node, None] = {}
for node in graph.nodes:
if node.op == 'output':
if node.op == "output":
map_arg(node.args, lambda n: input_nodes.setdefault(n))
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
placeholder_nodes = []
for node in input_nodes.keys():
if node.op == 'placeholder':
if node.op == "placeholder":
placeholder_nodes.append(node)
for node in placeholder_nodes:
input_nodes.pop(node)
......@@ -60,13 +62,13 @@ def get_top(graph: Graph):
"""
top_node_list = set()
for node in graph.nodes:
if node.op == 'output':
if node.op == "output":
continue
is_top = False
def _get_top(node):
nonlocal is_top
if node.op == 'placeholder':
if node.op == "placeholder":
is_top = True
map_arg(node.args, lambda n: _get_top(n))
......@@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node):
def get_all_consumers(graph: Graph, node: Node):
"""
Given a graph and a node of this graph, return all consumers of the node.
Returns:
List of ``Nodes`` in whose ``args`` and ``kwargs`` the given node appears.
"""
......@@ -120,7 +122,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
for node in gm.graph.nodes:
if hasattr(node, 'bfs_level'):
print(node.name, node.bfs_level)
Output:
graph():
%x : [#users=2] = placeholder[target=x]
......@@ -148,7 +150,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
while nodes_to_process:
new_process_list = []
for node in nodes_to_process:
if node.op == 'output':
if node.op == "output":
continue
node.bfs_level = current_level
new_process_list.extend(get_all_consumers(graph, node))
......@@ -165,8 +167,9 @@ def get_node_module(node) -> torch.nn.Module:
torch.nn.Module: the module associated with the given node
"""
assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object'
assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
assert (
node.graph.owning_module is not None
), "Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object"
assert node.op == "call_module", f"Expected node.op to be call_module, but found {node.op}"
module = node.graph.owning_module.get_submodule(node.target)
return module
......@@ -12,7 +12,16 @@ if is_compatible_with_meta():
)
from .tensor import MetaTensor
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
from .experimental import (
meta_profiler_function,
meta_profiler_module,
profile_function,
profile_method,
profile_module,
calculate_fwd_in,
calculate_fwd_tmp,
calculate_fwd_out,
)
from .dataflow import GraphInfo
from .memory_utils import activation_size, is_inplace, parameter_size
import torch
__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'RELU_LIKE_OPS', 'RELU_LIKE_MOD']
__all__ = ["ALIAS_ATEN", "INPLACE_NEW", "INPLACE_MATH_ATEN", "CLONE_ATEN", "RELU_LIKE_OPS", "RELU_LIKE_MOD"]
aten = torch.ops.aten
......
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from typing import Dict, List
from torch.fx import Graph, Node
......@@ -69,8 +68,8 @@ class GraphInfo:
def is_phase(n: Node, phase: Phase) -> bool:
assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
return n.meta['phase'] == phase
assert "phase" in n.meta, f"Node meta of {n} has no key `phase`!"
return n.meta["phase"] == phase
@compatibility(is_backward_compatible=False)
......@@ -103,9 +102,9 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
peak_mem = 0
for k, v in deps.items():
if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
peak_mem += activation_size(k.meta['saved_tensor'])
if v <= float('-inf') and is_phase(k, Phase.FORWARD):
peak_mem -= activation_size(k.meta['saved_tensor'])
peak_mem += activation_size(k.meta["saved_tensor"])
if v <= float("-inf") and is_phase(k, Phase.FORWARD):
peak_mem -= activation_size(k.meta["saved_tensor"])
return peak_mem
# deps is used to track all the memory dependencies of the graph.
......@@ -123,19 +122,19 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER):
graph_info.fwd_in += n.meta['saved_tensor']
graph_info.fwd_in += n.meta["saved_tensor"]
if is_phase(n, Phase.FORWARD):
graph_info.fwd_tmp += n.meta['saved_tensor']
graph_info.fwd_tmp += n.meta["saved_tensor"]
elif is_phase(n, Phase.BACKWARD):
if len(n.users):
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
else:
# TODO: some of the bwd_mem_out might be model parameters.
# basically a backward node without user is a `grad_out` node
graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
graph_info.bwd_mem_out += activation_size(n.meta["saved_tensor"])
for input_n in n.all_input_nodes:
if input_n in deps:
deps[input_n] -= 1
if deps[input_n] <= 0:
deps[input_n] = float('-inf')
deps[input_n] = float("-inf")
return graph_info
......@@ -2,7 +2,7 @@ from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub
import torch
__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
__all__ = ["INPLACE_OPS", "INPLACE_METHOD", "NON_INPLACE_METHOD"]
# TODO fill out the inplace ops
INPLACE_OPS = [
......@@ -20,25 +20,25 @@ INPLACE_OPS = [
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
'transpose',
'permute',
"transpose",
"permute",
# TODO: reshape may return a copy of the data if the data is not contiguous
'reshape',
'dim',
'flatten',
'size',
'view',
'unsqueeze',
'to',
'type',
'flatten',
"reshape",
"dim",
"flatten",
"size",
"view",
"unsqueeze",
"to",
"type",
"flatten",
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
'chunk',
'contiguous',
'expand',
'mean',
'split',
"chunk",
"contiguous",
"expand",
"mean",
"split",
]
......@@ -9,7 +9,7 @@ from ..memory_utils import activation_size
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module
__all__ = ['profile_function', 'profile_module', 'profile_method']
__all__ = ["profile_function", "profile_module", "profile_method"]
# this is for compatibility use
......@@ -42,6 +42,7 @@ class GraphInfo:
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
fwd_flop: int = 0
bwd_flop: int = 0
fwd_mem_in: int = 0
......@@ -50,8 +51,7 @@ class GraphInfo:
bwd_mem_out: int = 0
CALL_FUNCTION_MSG = \
"""
CALL_FUNCTION_MSG = """
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_function
@meta_profiler_function.register(YOUR_FUNCTION)
......@@ -60,9 +60,8 @@ def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
macs = ...
return flops, macs
"""
CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
CALL_MODULE_MSG = \
"""
CALL_METHOD_MSG = "Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}"
CALL_MODULE_MSG = """
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_module
@meta_profiler_module.register(YOUR_MODULE)
......@@ -74,7 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target') -> Callable:
def profile_function(target: "Target") -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
......@@ -92,12 +91,13 @@ def profile_function(target: 'Target') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
target.__name__), CALL_FUNCTION_MSG.format(target)
target.__name__
), CALL_FUNCTION_MSG.format(target)
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
if target not in INPLACE_OPS and not kwargs.get('inplace', False):
if target not in INPLACE_OPS and not kwargs.get("inplace", False):
fwd_out = activation_size(out)
if meta_profiler_function.has(target):
profiler = meta_profiler_function.get(target)
......@@ -112,7 +112,7 @@ def profile_function(target: 'Target') -> Callable:
@compatibility(is_backward_compatible=True)
def profile_method(target: 'Target') -> Callable:
def profile_method(target: "Target") -> Callable:
"""
Wrap a `call_method` node
record the memory cost and FLOPs of the execution.
......@@ -126,11 +126,12 @@ def profile_method(target: 'Target') -> Callable:
self_obj, *args_tail = args
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
assert isinstance(target, str), f"{target} instance is not str."
out = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
target, INPLACE_METHOD, NON_INPLACE_METHOD)
target, INPLACE_METHOD, NON_INPLACE_METHOD
)
# call_method targets have no parameters, are MOSTLY(?) inplace, and have no FLOPs or MACs.
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
......@@ -161,7 +162,7 @@ def profile_module(module: torch.nn.Module) -> Callable:
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
if getattr(module, 'inplace', False):
if getattr(module, "inplace", False):
fwd_out = activation_size(out)
profiler = meta_profiler_module.get(type(module))
fwd_flop, _ = profiler(module, *args, **kwargs)
......
from typing import Tuple
import torch
from ..registry import meta_profiler_function
# TODO: different activation has different FLOPs count, currently unused.
......
......@@ -41,15 +41,15 @@ def _elementwise_flops_compute(input, other):
@meta_profiler_function.register(torch.sub)
@meta_profiler_function.register(torch.mul)
@meta_profiler_function.register(torch.floor_divide)
@meta_profiler_function.register('add') # for built-in op +
@meta_profiler_function.register('iadd') # for built-in op +=
@meta_profiler_function.register('eq') # for built-in op ==
@meta_profiler_function.register('sub') # for built-in op -
@meta_profiler_function.register('isub') # for built-in op -=
@meta_profiler_function.register('mul') # for built-in op *
@meta_profiler_function.register('imul') # for built-in op *=
@meta_profiler_function.register('floordiv') # for built-in op //
@meta_profiler_function.register('ifloordiv') # for built-in op //=
@meta_profiler_function.register("add") # for built-in op +
@meta_profiler_function.register("iadd") # for built-in op +=
@meta_profiler_function.register("eq") # for built-in op ==
@meta_profiler_function.register("sub") # for built-in op -
@meta_profiler_function.register("isub") # for built-in op -=
@meta_profiler_function.register("mul") # for built-in op *
@meta_profiler_function.register("imul") # for built-in op *=
@meta_profiler_function.register("floordiv") # for built-in op //
@meta_profiler_function.register("ifloordiv") # for built-in op //=
def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
return _elementwise_flops_compute(input, other)
......@@ -62,7 +62,7 @@ def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = N
@meta_profiler_function.register(torch.matmul)
@meta_profiler_function.register('matmul') # for built-in op @
@meta_profiler_function.register("matmul") # for built-in op @
@meta_profiler_function.register(torch.Tensor.matmul)
def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = reduce(operator.mul, input.shape) * other.shape[-1]
......@@ -78,13 +78,15 @@ def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.T
@meta_profiler_function.register(torch.var_mean)
def torch_var_mean(input: torch.Tensor,
dim: Union[int, Tuple[int, ...]],
unbiased: Optional[bool] = True,
keepdim: Optional[bool] = False,
*,
out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
assert out is None, 'saving to out is not supported yet'
def torch_var_mean(
input: torch.Tensor,
dim: Union[int, Tuple[int, ...]],
unbiased: Optional[bool] = True,
keepdim: Optional[bool] = False,
*,
out: Optional[torch.Tensor] = None,
) -> Tuple[int, int]:
assert out is None, "saving to out is not supported yet"
flops = input.numel() * 3
macs = 0
return flops, macs
import torch
from typing import Optional
import torch
from ..registry import meta_profiler_function
......
from typing import Tuple
import torch
from ..registry import meta_profiler_function
......
from typing import List, Optional, Tuple
import torch
from ..registry import meta_profiler_function
......@@ -21,11 +23,13 @@ def torch_nn_func_instancenorm(
@meta_profiler_function.register(torch.nn.functional.group_norm)
def torch_nn_func_groupnorm(input: torch.Tensor,
num_groups: int,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
eps: float = 1e-5) -> Tuple[int, int]:
def torch_nn_func_groupnorm(
input: torch.Tensor,
num_groups: int,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
eps: float = 1e-5,
) -> Tuple[int, int]:
has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4)
macs = 0
......
from typing import Tuple, Union
from typing import Tuple
import torch
from ..registry import meta_profiler_function
......
import operator
from typing import Any, Tuple
import torch
from ..registry import meta_profiler_function
......
from functools import reduce
import operator
from functools import reduce
from typing import Any, Optional, Tuple
import torch
from ..registry import meta_profiler_function
......@@ -43,13 +45,11 @@ def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]:
@meta_profiler_function.register(torch.max)
def torch_max(input: torch.Tensor,
dim: int = None,
keepdim: bool = False,
*,
out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
def torch_max(
input: torch.Tensor, dim: int = None, keepdim: bool = False, *, out: Optional[torch.Tensor] = None
) -> Tuple[int, int]:
macs = 0
assert out is None, 'assigning value to out is not supported yet'
assert out is None, "assigning value to out is not supported yet"
if dim is not None:
shape = list(input.shape)
shape.pop(int(dim))
......
from typing import Tuple
import torch
from ..registry import meta_profiler_module
# TODO: different activation has different FLOPs count, currently unused.
......
from typing import Optional, Tuple
import torch
from ..registry import meta_profiler_module
# TODO: This is hard to compute memory cost
@meta_profiler_module.register(torch.nn.MultiheadAttention)
def torch_nn_msa(self: torch.nn.MultiheadAttention,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[torch.Tensor] = None,
average_attn_weights: bool = True) -> Tuple[int, int]:
if getattr(self, 'batch_first', False):
def torch_nn_msa(
self: torch.nn.MultiheadAttention,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[torch.Tensor] = None,
average_attn_weights: bool = True,
) -> Tuple[int, int]:
if getattr(self, "batch_first", False):
batch_size = query.shape[0]
len_idx = 1
else:
......@@ -44,15 +48,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention,
flops += qlen * qdim
# Initial projections
flops += 2 * ((qlen * qdim * qdim) # QW
+ (klen * kdim * kdim) # KW
+ (vlen * vdim * vdim) # VW
)
flops += 2 * ((qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim)) # QW # KW # VW
macs += ((qlen * qdim * qdim) # QW
+ (klen * kdim * kdim) # KW
+ (vlen * vdim * vdim) # VW
)
macs += (qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim) # QW # KW # VW
if self.in_proj_bias is not None:
flops += (qlen + klen + vlen) * qdim
......@@ -62,13 +60,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention,
v_head_dim = vdim // num_heads
head_flops = (
2 * (qlen * klen * qk_head_dim) # QK^T
+ (qlen * klen) # softmax
+ 2 * (qlen * klen * v_head_dim) # AV
2 * (qlen * klen * qk_head_dim) + (qlen * klen) + 2 * (qlen * klen * v_head_dim) # QK^T # softmax # AV
)
head_macs = ((qlen * klen * qk_head_dim) # QK^T
+ 2 * (qlen * klen * v_head_dim) # AV
)
head_macs = (qlen * klen * qk_head_dim) + 2 * (qlen * klen * v_head_dim) # QK^T # AV
flops += num_heads * head_flops
macs += num_heads * head_flops
......
......@@ -17,8 +17,9 @@ def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
l_out = math.floor(
(l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
)
result_shape = input.shape[:-2] + (
c_out,
l_out,
......@@ -38,10 +39,12 @@ def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
h_out = math.floor(
(h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
)
w_out = math.floor(
(w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
)
result_shape = input.shape[:-3] + (
c_out,
h_out,
......@@ -62,12 +65,15 @@ def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
(self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
d_out = math.floor(
(d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
)
h_out = math.floor(
(h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
)
w_out = math.floor(
(w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1
)
result_shape = input.shape[:-4] + (
c_out,
d_out,
......@@ -89,8 +95,13 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
l_out = math.floor(
(l_in - 1) * self.stride[0]
- 2 * self.padding[0]
+ self.dilation[0] * (self.kernel_size[0] - 1)
+ self.output_padding[0]
+ 1
)
result_shape = input.shape[:-2] + (
c_out,
l_out,
......@@ -98,7 +109,7 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(
operator.mul, input.shape
) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
......@@ -112,10 +123,20 @@ def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
(self.kernel_size[1] - 1) + self.output_padding[1] + 1)
h_out = math.floor(
(h_in - 1) * self.stride[0]
- 2 * self.padding[0]
+ self.dilation[0] * (self.kernel_size[0] - 1)
+ self.output_padding[0]
+ 1
)
w_out = math.floor(
(w_in - 1) * self.stride[1]
- 2 * self.padding[1]
+ self.dilation[1] * (self.kernel_size[1] - 1)
+ self.output_padding[1]
+ 1
)
result_shape = input.shape[:-3] + (
c_out,
h_out,
......@@ -136,12 +157,27 @@ def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
(self.kernel_size[1] - 1) + self.output_padding[1] + 1)
w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] *
(self.kernel_size[2] - 1) + self.output_padding[2] + 1)
d_out = math.floor(
(d_in - 1) * self.stride[0]
- 2 * self.padding[0]
+ self.dilation[0] * (self.kernel_size[0] - 1)
+ self.output_padding[0]
+ 1
)
h_out = math.floor(
(h_in - 1) * self.stride[1]
- 2 * self.padding[1]
+ self.dilation[1] * (self.kernel_size[1] - 1)
+ self.output_padding[1]
+ 1
)
w_out = math.floor(
(w_in - 1) * self.stride[2]
- 2 * self.padding[2]
+ self.dilation[2] * (self.kernel_size[2] - 1)
+ self.output_padding[2]
+ 1
)
result_shape = input.shape[:-4] + (
c_out,
d_out,
......
from typing import Tuple
import torch
from ..registry import meta_profiler_module
......
from typing import Tuple
import torch
from ..registry import meta_profiler_module
......
......@@ -16,8 +16,12 @@ from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.BatchNorm1d)
@meta_profiler_module.register(torch.nn.BatchNorm2d)
@meta_profiler_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]:
def torch_nn_normalize(
self: Union[
torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d
],
input: torch.Tensor,
) -> Tuple[int, int]:
# adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615
has_affine = self.weight is not None
if self.training:
......@@ -30,6 +34,7 @@ def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch
try:
import apex
meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment