Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
import torch
from typing import Dict from typing import Dict
from torch.fx.node import Node, map_arg
import torch
from torch.fx.graph import Graph from torch.fx.graph import Graph
from torch.fx.node import Node, map_arg
def get_comm_size(prev_partition, next_partition): def get_comm_size(prev_partition, next_partition):
""" """
...@@ -23,7 +25,7 @@ def get_comm_size(prev_partition, next_partition): ...@@ -23,7 +25,7 @@ def get_comm_size(prev_partition, next_partition):
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
for n in input_nodes: for n in input_nodes:
if n.name in parent_node_names and n not in visited_nodes: if n.name in parent_node_names and n not in visited_nodes:
comm_size += n.meta['tensor_meta'].numel comm_size += n.meta["tensor_meta"].numel
visited_nodes.add(n) visited_nodes.add(n)
return comm_size return comm_size
...@@ -36,12 +38,12 @@ def get_leaf(graph: Graph): ...@@ -36,12 +38,12 @@ def get_leaf(graph: Graph):
""" """
input_nodes: Dict[Node, None] = {} input_nodes: Dict[Node, None] = {}
for node in graph.nodes: for node in graph.nodes:
if node.op == 'output': if node.op == "output":
map_arg(node.args, lambda n: input_nodes.setdefault(n)) map_arg(node.args, lambda n: input_nodes.setdefault(n))
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
placeholder_nodes = [] placeholder_nodes = []
for node in input_nodes.keys(): for node in input_nodes.keys():
if node.op == 'placeholder': if node.op == "placeholder":
placeholder_nodes.append(node) placeholder_nodes.append(node)
for node in placeholder_nodes: for node in placeholder_nodes:
input_nodes.pop(node) input_nodes.pop(node)
...@@ -60,13 +62,13 @@ def get_top(graph: Graph): ...@@ -60,13 +62,13 @@ def get_top(graph: Graph):
""" """
top_node_list = set() top_node_list = set()
for node in graph.nodes: for node in graph.nodes:
if node.op == 'output': if node.op == "output":
continue continue
is_top = False is_top = False
def _get_top(node): def _get_top(node):
nonlocal is_top nonlocal is_top
if node.op == 'placeholder': if node.op == "placeholder":
is_top = True is_top = True
map_arg(node.args, lambda n: _get_top(n)) map_arg(node.args, lambda n: _get_top(n))
...@@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node): ...@@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node):
def get_all_consumers(graph: Graph, node: Node): def get_all_consumers(graph: Graph, node: Node):
""" """
Given a graph and a node of this graph, return all consumers of the node. Given a graph and a node of this graph, return all consumers of the node.
Returns: Returns:
List of ``Nodes`` that node appear in these nodes ``args`` and ``kwargs``. List of ``Nodes`` that node appear in these nodes ``args`` and ``kwargs``.
""" """
...@@ -120,7 +122,7 @@ def assign_bfs_level_to_nodes(graph: Graph): ...@@ -120,7 +122,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
for node in gm.graph.nodes: for node in gm.graph.nodes:
if hasattr(node, 'bfs_level'): if hasattr(node, 'bfs_level'):
print(node.name, node.bfs_level) print(node.name, node.bfs_level)
Output: Output:
graph(): graph():
%x : [#users=2] = placeholder[target=x] %x : [#users=2] = placeholder[target=x]
...@@ -148,7 +150,7 @@ def assign_bfs_level_to_nodes(graph: Graph): ...@@ -148,7 +150,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
while nodes_to_process: while nodes_to_process:
new_process_list = [] new_process_list = []
for node in nodes_to_process: for node in nodes_to_process:
if node.op == 'output': if node.op == "output":
continue continue
node.bfs_level = current_level node.bfs_level = current_level
new_process_list.extend(get_all_consumers(graph, node)) new_process_list.extend(get_all_consumers(graph, node))
...@@ -165,8 +167,9 @@ def get_node_module(node) -> torch.nn.Module: ...@@ -165,8 +167,9 @@ def get_node_module(node) -> torch.nn.Module:
torch.nn.Module: the module associated with the given node torch.nn.Module: the module associated with the given node
""" """
assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object' assert (
assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}' node.graph.owning_module is not None
), "Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object"
assert node.op == "call_module", f"Expected node.op to be call_module, but found {node.op}"
module = node.graph.owning_module.get_submodule(node.target) module = node.graph.owning_module.get_submodule(node.target)
return module return module
...@@ -12,7 +12,16 @@ if is_compatible_with_meta(): ...@@ -12,7 +12,16 @@ if is_compatible_with_meta():
) )
from .tensor import MetaTensor from .tensor import MetaTensor
else: else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out from .experimental import (
meta_profiler_function,
meta_profiler_module,
profile_function,
profile_method,
profile_module,
calculate_fwd_in,
calculate_fwd_tmp,
calculate_fwd_out,
)
from .dataflow import GraphInfo from .dataflow import GraphInfo
from .memory_utils import activation_size, is_inplace, parameter_size from .memory_utils import activation_size, is_inplace, parameter_size
import torch import torch
__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'RELU_LIKE_OPS', 'RELU_LIKE_MOD'] __all__ = ["ALIAS_ATEN", "INPLACE_NEW", "INPLACE_MATH_ATEN", "CLONE_ATEN", "RELU_LIKE_OPS", "RELU_LIKE_MOD"]
aten = torch.ops.aten aten = torch.ops.aten
......
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from functools import partial
from typing import Dict, List from typing import Dict, List
from torch.fx import Graph, Node from torch.fx import Graph, Node
...@@ -69,8 +68,8 @@ class GraphInfo: ...@@ -69,8 +68,8 @@ class GraphInfo:
def is_phase(n: Node, phase: Phase) -> bool: def is_phase(n: Node, phase: Phase) -> bool:
assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' assert "phase" in n.meta, f"Node meta of {n} has no key `phase`!"
return n.meta['phase'] == phase return n.meta["phase"] == phase
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
...@@ -103,9 +102,9 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: ...@@ -103,9 +102,9 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
peak_mem = 0 peak_mem = 0
for k, v in deps.items(): for k, v in deps.items():
if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k): if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
peak_mem += activation_size(k.meta['saved_tensor']) peak_mem += activation_size(k.meta["saved_tensor"])
if v <= float('-inf') and is_phase(k, Phase.FORWARD): if v <= float("-inf") and is_phase(k, Phase.FORWARD):
peak_mem -= activation_size(k.meta['saved_tensor']) peak_mem -= activation_size(k.meta["saved_tensor"])
return peak_mem return peak_mem
# deps is used to track all the memory dependencies of the graph. # deps is used to track all the memory dependencies of the graph.
...@@ -123,19 +122,19 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: ...@@ -123,19 +122,19 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed. # the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER): if is_phase(n, Phase.PLACEHOLDER):
graph_info.fwd_in += n.meta['saved_tensor'] graph_info.fwd_in += n.meta["saved_tensor"]
if is_phase(n, Phase.FORWARD): if is_phase(n, Phase.FORWARD):
graph_info.fwd_tmp += n.meta['saved_tensor'] graph_info.fwd_tmp += n.meta["saved_tensor"]
elif is_phase(n, Phase.BACKWARD): elif is_phase(n, Phase.BACKWARD):
if len(n.users): if len(n.users):
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
else: else:
# TODO: some of the bwd_mem_out might be model parameters. # TODO: some of the bwd_mem_out might be model parameters.
# basically a backward node without user is a `grad_out` node # basically a backward node without user is a `grad_out` node
graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor']) graph_info.bwd_mem_out += activation_size(n.meta["saved_tensor"])
for input_n in n.all_input_nodes: for input_n in n.all_input_nodes:
if input_n in deps: if input_n in deps:
deps[input_n] -= 1 deps[input_n] -= 1
if deps[input_n] <= 0: if deps[input_n] <= 0:
deps[input_n] = float('-inf') deps[input_n] = float("-inf")
return graph_info return graph_info
...@@ -2,7 +2,7 @@ from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub ...@@ -2,7 +2,7 @@ from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub
import torch import torch
__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD'] __all__ = ["INPLACE_OPS", "INPLACE_METHOD", "NON_INPLACE_METHOD"]
# TODO fill out the inplace ops # TODO fill out the inplace ops
INPLACE_OPS = [ INPLACE_OPS = [
...@@ -20,25 +20,25 @@ INPLACE_OPS = [ ...@@ -20,25 +20,25 @@ INPLACE_OPS = [
# TODO: list all call_methods that are inplace here # TODO: list all call_methods that are inplace here
INPLACE_METHOD = [ INPLACE_METHOD = [
'transpose', "transpose",
'permute', "permute",
# TODO: reshape may return a copy of the data if the data is not contiguous # TODO: reshape may return a copy of the data if the data is not contiguous
'reshape', "reshape",
'dim', "dim",
'flatten', "flatten",
'size', "size",
'view', "view",
'unsqueeze', "unsqueeze",
'to', "to",
'type', "type",
'flatten', "flatten",
] ]
# TODO: list all call_methods that are not inplace here # TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [ NON_INPLACE_METHOD = [
'chunk', "chunk",
'contiguous', "contiguous",
'expand', "expand",
'mean', "mean",
'split', "split",
] ]
...@@ -9,7 +9,7 @@ from ..memory_utils import activation_size ...@@ -9,7 +9,7 @@ from ..memory_utils import activation_size
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module from .registry import meta_profiler_function, meta_profiler_module
__all__ = ['profile_function', 'profile_module', 'profile_method'] __all__ = ["profile_function", "profile_module", "profile_method"]
# this is for compatibility use # this is for compatibility use
...@@ -42,6 +42,7 @@ class GraphInfo: ...@@ -42,6 +42,7 @@ class GraphInfo:
bwd_mem_tmp (int): See the above illustration. bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration. bwd_mem_out (int): See the above illustration.
""" """
fwd_flop: int = 0 fwd_flop: int = 0
bwd_flop: int = 0 bwd_flop: int = 0
fwd_mem_in: int = 0 fwd_mem_in: int = 0
...@@ -50,8 +51,7 @@ class GraphInfo: ...@@ -50,8 +51,7 @@ class GraphInfo:
bwd_mem_out: int = 0 bwd_mem_out: int = 0
CALL_FUNCTION_MSG = \ CALL_FUNCTION_MSG = """
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_function from colossalai.fx.profiler.experimental import meta_profiler_function
@meta_profiler_function.register(YOUR_FUNCTION) @meta_profiler_function.register(YOUR_FUNCTION)
...@@ -60,9 +60,8 @@ def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]: ...@@ -60,9 +60,8 @@ def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
macs = ... macs = ...
return flops, macs return flops, macs
""" """
CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}' CALL_METHOD_MSG = "Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}"
CALL_MODULE_MSG = \ CALL_MODULE_MSG = """
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_module from colossalai.fx.profiler.experimental import meta_profiler_module
@meta_profiler_module.register(YOUR_MODULE) @meta_profiler_module.register(YOUR_MODULE)
...@@ -74,7 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int ...@@ -74,7 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def profile_function(target: 'Target') -> Callable: def profile_function(target: "Target") -> Callable:
""" """
Wrap a `call_function` node or `torch.nn.functional` in order to Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution. record the memory cost and FLOPs of the execution.
...@@ -92,12 +91,13 @@ def profile_function(target: 'Target') -> Callable: ...@@ -92,12 +91,13 @@ def profile_function(target: 'Target') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has( assert meta_profiler_function.has(target) or meta_profiler_function.has(
target.__name__), CALL_FUNCTION_MSG.format(target) target.__name__
), CALL_FUNCTION_MSG.format(target)
fwd_tmp = 0 fwd_tmp = 0
fwd_out = 0 fwd_out = 0
out = func(*args, **kwargs) out = func(*args, **kwargs)
if target not in INPLACE_OPS and not kwargs.get('inplace', False): if target not in INPLACE_OPS and not kwargs.get("inplace", False):
fwd_out = activation_size(out) fwd_out = activation_size(out)
if meta_profiler_function.has(target): if meta_profiler_function.has(target):
profiler = meta_profiler_function.get(target) profiler = meta_profiler_function.get(target)
...@@ -112,7 +112,7 @@ def profile_function(target: 'Target') -> Callable: ...@@ -112,7 +112,7 @@ def profile_function(target: 'Target') -> Callable:
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def profile_method(target: 'Target') -> Callable: def profile_method(target: "Target") -> Callable:
""" """
Wrap a `call_method` node Wrap a `call_method` node
record the memory cost and FLOPs of the execution. record the memory cost and FLOPs of the execution.
...@@ -126,11 +126,12 @@ def profile_method(target: 'Target') -> Callable: ...@@ -126,11 +126,12 @@ def profile_method(target: 'Target') -> Callable:
self_obj, *args_tail = args self_obj, *args_tail = args
# execute the method and return the result # execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.' assert isinstance(target, str), f"{target} instance is not str."
out = getattr(self_obj, target)(*args_tail, **kwargs) out = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format( assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
target, INPLACE_METHOD, NON_INPLACE_METHOD) target, INPLACE_METHOD, NON_INPLACE_METHOD
)
# call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out) fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out) fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
...@@ -161,7 +162,7 @@ def profile_module(module: torch.nn.Module) -> Callable: ...@@ -161,7 +162,7 @@ def profile_module(module: torch.nn.Module) -> Callable:
fwd_tmp = 0 fwd_tmp = 0
fwd_out = 0 fwd_out = 0
out = func(*args, **kwargs) out = func(*args, **kwargs)
if getattr(module, 'inplace', False): if getattr(module, "inplace", False):
fwd_out = activation_size(out) fwd_out = activation_size(out)
profiler = meta_profiler_module.get(type(module)) profiler = meta_profiler_module.get(type(module))
fwd_flop, _ = profiler(module, *args, **kwargs) fwd_flop, _ = profiler(module, *args, **kwargs)
......
from typing import Tuple from typing import Tuple
import torch import torch
from ..registry import meta_profiler_function from ..registry import meta_profiler_function
# TODO: different activation has different FLOPs count, currently unused. # TODO: different activation has different FLOPs count, currently unused.
......
...@@ -41,15 +41,15 @@ def _elementwise_flops_compute(input, other): ...@@ -41,15 +41,15 @@ def _elementwise_flops_compute(input, other):
@meta_profiler_function.register(torch.sub) @meta_profiler_function.register(torch.sub)
@meta_profiler_function.register(torch.mul) @meta_profiler_function.register(torch.mul)
@meta_profiler_function.register(torch.floor_divide) @meta_profiler_function.register(torch.floor_divide)
@meta_profiler_function.register('add') # for built-in op + @meta_profiler_function.register("add") # for built-in op +
@meta_profiler_function.register('iadd') # for built-in op += @meta_profiler_function.register("iadd") # for built-in op +=
@meta_profiler_function.register('eq') # for built-in op = @meta_profiler_function.register("eq") # for built-in op =
@meta_profiler_function.register('sub') # for built-in op - @meta_profiler_function.register("sub") # for built-in op -
@meta_profiler_function.register('isub') # for built-in op -= @meta_profiler_function.register("isub") # for built-in op -=
@meta_profiler_function.register('mul') # for built-in op * @meta_profiler_function.register("mul") # for built-in op *
@meta_profiler_function.register('imul') # for built-in op *= @meta_profiler_function.register("imul") # for built-in op *=
@meta_profiler_function.register('floordiv') # for built-in op // @meta_profiler_function.register("floordiv") # for built-in op //
@meta_profiler_function.register('ifloordiv') # for built-in op //= @meta_profiler_function.register("ifloordiv") # for built-in op //=
def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
return _elementwise_flops_compute(input, other) return _elementwise_flops_compute(input, other)
...@@ -62,7 +62,7 @@ def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = N ...@@ -62,7 +62,7 @@ def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = N
@meta_profiler_function.register(torch.matmul) @meta_profiler_function.register(torch.matmul)
@meta_profiler_function.register('matmul') # for built-in op @ @meta_profiler_function.register("matmul") # for built-in op @
@meta_profiler_function.register(torch.Tensor.matmul) @meta_profiler_function.register(torch.Tensor.matmul)
def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = reduce(operator.mul, input.shape) * other.shape[-1] macs = reduce(operator.mul, input.shape) * other.shape[-1]
...@@ -78,13 +78,15 @@ def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.T ...@@ -78,13 +78,15 @@ def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.T
@meta_profiler_function.register(torch.var_mean) @meta_profiler_function.register(torch.var_mean)
def torch_var_mean(input: torch.Tensor, def torch_var_mean(
dim: Union[int, Tuple[int, ...]], input: torch.Tensor,
unbiased: Optional[bool] = True, dim: Union[int, Tuple[int, ...]],
keepdim: Optional[bool] = False, unbiased: Optional[bool] = True,
*, keepdim: Optional[bool] = False,
out: Optional[torch.Tensor] = None) -> Tuple[int, int]: *,
assert out is None, 'saving to out is not supported yet' out: Optional[torch.Tensor] = None,
) -> Tuple[int, int]:
assert out is None, "saving to out is not supported yet"
flops = input.numel() * 3 flops = input.numel() * 3
macs = 0 macs = 0
return flops, macs return flops, macs
import torch
from typing import Optional from typing import Optional
import torch
from ..registry import meta_profiler_function from ..registry import meta_profiler_function
......
from typing import Tuple from typing import Tuple
import torch import torch
from ..registry import meta_profiler_function from ..registry import meta_profiler_function
......
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from ..registry import meta_profiler_function from ..registry import meta_profiler_function
...@@ -21,11 +23,13 @@ def torch_nn_func_instancenorm( ...@@ -21,11 +23,13 @@ def torch_nn_func_instancenorm(
@meta_profiler_function.register(torch.nn.functional.group_norm) @meta_profiler_function.register(torch.nn.functional.group_norm)
def torch_nn_func_groupnorm(input: torch.Tensor, def torch_nn_func_groupnorm(
num_groups: int, input: torch.Tensor,
weight: Optional[torch.Tensor] = None, num_groups: int,
bias: Optional[torch.Tensor] = None, weight: Optional[torch.Tensor] = None,
eps: float = 1e-5) -> Tuple[int, int]: bias: Optional[torch.Tensor] = None,
eps: float = 1e-5,
) -> Tuple[int, int]:
has_affine = weight is not None has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4) flops = input.numel() * (5 if has_affine else 4)
macs = 0 macs = 0
......
from typing import Tuple, Union from typing import Tuple
import torch import torch
from ..registry import meta_profiler_function from ..registry import meta_profiler_function
......
import operator import operator
from typing import Any, Tuple from typing import Any, Tuple
import torch
from ..registry import meta_profiler_function from ..registry import meta_profiler_function
......
from functools import reduce
import operator import operator
from functools import reduce
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
import torch import torch
from ..registry import meta_profiler_function from ..registry import meta_profiler_function
...@@ -43,13 +45,11 @@ def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]: ...@@ -43,13 +45,11 @@ def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]:
@meta_profiler_function.register(torch.max) @meta_profiler_function.register(torch.max)
def torch_max(input: torch.Tensor, def torch_max(
dim: int = None, input: torch.Tensor, dim: int = None, keepdim: bool = False, *, out: Optional[torch.Tensor] = None
keepdim: bool = False, ) -> Tuple[int, int]:
*,
out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = 0 macs = 0
assert out is None, 'assigning value to out is not supported yet' assert out is None, "assigning value to out is not supported yet"
if dim is not None: if dim is not None:
shape = list(input.shape) shape = list(input.shape)
shape.pop(int(dim)) shape.pop(int(dim))
......
from typing import Tuple from typing import Tuple
import torch import torch
from ..registry import meta_profiler_module from ..registry import meta_profiler_module
# TODO: different activation has different FLOPs count, currently unused. # TODO: different activation has different FLOPs count, currently unused.
......
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
from ..registry import meta_profiler_module from ..registry import meta_profiler_module
# TODO: This is hard to compute memory cost # TODO: This is hard to compute memory cost
@meta_profiler_module.register(torch.nn.MultiheadAttention) @meta_profiler_module.register(torch.nn.MultiheadAttention)
def torch_nn_msa(self: torch.nn.MultiheadAttention, def torch_nn_msa(
query: torch.Tensor, self: torch.nn.MultiheadAttention,
key: torch.Tensor, query: torch.Tensor,
value: torch.Tensor, key: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None, value: torch.Tensor,
need_weights: bool = True, key_padding_mask: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None, need_weights: bool = True,
average_attn_weights: bool = True) -> Tuple[int, int]: attn_mask: Optional[torch.Tensor] = None,
if getattr(self, 'batch_first', False): average_attn_weights: bool = True,
) -> Tuple[int, int]:
if getattr(self, "batch_first", False):
batch_size = query.shape[0] batch_size = query.shape[0]
len_idx = 1 len_idx = 1
else: else:
...@@ -44,15 +48,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention, ...@@ -44,15 +48,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention,
flops += qlen * qdim flops += qlen * qdim
# Initial projections # Initial projections
flops += 2 * ((qlen * qdim * qdim) # QW flops += 2 * ((qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim)) # QW # KW # VW
+ (klen * kdim * kdim) # KW
+ (vlen * vdim * vdim) # VW
)
macs += ((qlen * qdim * qdim) # QW macs += (qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim) # QW # KW # VW
+ (klen * kdim * kdim) # KW
+ (vlen * vdim * vdim) # VW
)
if self.in_proj_bias is not None: if self.in_proj_bias is not None:
flops += (qlen + klen + vlen) * qdim flops += (qlen + klen + vlen) * qdim
...@@ -62,13 +60,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention, ...@@ -62,13 +60,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention,
v_head_dim = vdim // num_heads v_head_dim = vdim // num_heads
head_flops = ( head_flops = (
2 * (qlen * klen * qk_head_dim) # QK^T 2 * (qlen * klen * qk_head_dim) + (qlen * klen) + 2 * (qlen * klen * v_head_dim) # QK^T # softmax # AV
+ (qlen * klen) # softmax
+ 2 * (qlen * klen * v_head_dim) # AV
) )
head_macs = ((qlen * klen * qk_head_dim) # QK^T head_macs = (qlen * klen * qk_head_dim) + 2 * (qlen * klen * v_head_dim) # QK^T # AV
+ 2 * (qlen * klen * v_head_dim) # AV
)
flops += num_heads * head_flops flops += num_heads * head_flops
macs += num_heads * head_flops macs += num_heads * head_flops
......
...@@ -17,8 +17,9 @@ def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, in ...@@ -17,8 +17,9 @@ def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
c_in, l_in = input.shape[-2:] c_in, l_in = input.shape[-2:]
c_out = self.out_channels c_out = self.out_channels
l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * l_out = math.floor(
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
)
result_shape = input.shape[:-2] + ( result_shape = input.shape[:-2] + (
c_out, c_out,
l_out, l_out,
...@@ -38,10 +39,12 @@ def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, in ...@@ -38,10 +39,12 @@ def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html # at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
c_in, h_in, w_in = input.shape[-3:] c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels c_out = self.out_channels
h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * h_out = math.floor(
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * )
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) w_out = math.floor(
(w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
)
result_shape = input.shape[:-3] + ( result_shape = input.shape[:-3] + (
c_out, c_out,
h_out, h_out,
...@@ -62,12 +65,15 @@ def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, in ...@@ -62,12 +65,15 @@ def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html # at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
c_in, d_in, h_in, w_in = input.shape[-4:] c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels c_out = self.out_channels
d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * d_out = math.floor(
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * )
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) h_out = math.floor(
w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
(self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) )
w_out = math.floor(
(w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1
)
result_shape = input.shape[:-4] + ( result_shape = input.shape[:-4] + (
c_out, c_out,
d_out, d_out,
...@@ -89,8 +95,13 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor ...@@ -89,8 +95,13 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
c_in, l_in = input.shape[-2:] c_in, l_in = input.shape[-2:]
c_out = self.out_channels c_out = self.out_channels
l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * l_out = math.floor(
(self.kernel_size[0] - 1) + self.output_padding[0] + 1) (l_in - 1) * self.stride[0]
- 2 * self.padding[0]
+ self.dilation[0] * (self.kernel_size[0] - 1)
+ self.output_padding[0]
+ 1
)
result_shape = input.shape[:-2] + ( result_shape = input.shape[:-2] + (
c_out, c_out,
l_out, l_out,
...@@ -98,7 +109,7 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor ...@@ -98,7 +109,7 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce( num_elem = reduce(
operator.mul, input.shape operator.mul, input.shape
) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604 ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
macs = macs_per_elem * num_elem macs = macs_per_elem * num_elem
flops = 2 * macs flops = 2 * macs
if self.bias is not None: if self.bias is not None:
...@@ -112,10 +123,20 @@ def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor ...@@ -112,10 +123,20 @@ def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
c_in, h_in, w_in = input.shape[-3:] c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels c_out = self.out_channels
h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * h_out = math.floor(
(self.kernel_size[0] - 1) + self.output_padding[0] + 1) (h_in - 1) * self.stride[0]
w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - 2 * self.padding[0]
(self.kernel_size[1] - 1) + self.output_padding[1] + 1) + self.dilation[0] * (self.kernel_size[0] - 1)
+ self.output_padding[0]
+ 1
)
w_out = math.floor(
(w_in - 1) * self.stride[1]
- 2 * self.padding[1]
+ self.dilation[1] * (self.kernel_size[1] - 1)
+ self.output_padding[1]
+ 1
)
result_shape = input.shape[:-3] + ( result_shape = input.shape[:-3] + (
c_out, c_out,
h_out, h_out,
...@@ -136,12 +157,27 @@ def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor ...@@ -136,12 +157,27 @@ def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
c_in, d_in, h_in, w_in = input.shape[-4:] c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels c_out = self.out_channels
d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * d_out = math.floor(
(self.kernel_size[0] - 1) + self.output_padding[0] + 1) (d_in - 1) * self.stride[0]
h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - 2 * self.padding[0]
(self.kernel_size[1] - 1) + self.output_padding[1] + 1) + self.dilation[0] * (self.kernel_size[0] - 1)
w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * + self.output_padding[0]
(self.kernel_size[2] - 1) + self.output_padding[2] + 1) + 1
)
h_out = math.floor(
(h_in - 1) * self.stride[1]
- 2 * self.padding[1]
+ self.dilation[1] * (self.kernel_size[1] - 1)
+ self.output_padding[1]
+ 1
)
w_out = math.floor(
(w_in - 1) * self.stride[2]
- 2 * self.padding[2]
+ self.dilation[2] * (self.kernel_size[2] - 1)
+ self.output_padding[2]
+ 1
)
result_shape = input.shape[:-4] + ( result_shape = input.shape[:-4] + (
c_out, c_out,
d_out, d_out,
......
from typing import Tuple from typing import Tuple
import torch import torch
from ..registry import meta_profiler_module from ..registry import meta_profiler_module
......
from typing import Tuple from typing import Tuple
import torch import torch
from ..registry import meta_profiler_module from ..registry import meta_profiler_module
......
...@@ -16,8 +16,12 @@ from ..registry import meta_profiler_module ...@@ -16,8 +16,12 @@ from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.BatchNorm1d) @meta_profiler_module.register(torch.nn.BatchNorm1d)
@meta_profiler_module.register(torch.nn.BatchNorm2d) @meta_profiler_module.register(torch.nn.BatchNorm2d)
@meta_profiler_module.register(torch.nn.BatchNorm3d) @meta_profiler_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, def torch_nn_normalize(
torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]: self: Union[
torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d
],
input: torch.Tensor,
) -> Tuple[int, int]:
# adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615 # adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615
has_affine = self.weight is not None has_affine = self.weight is not None
if self.training: if self.training:
...@@ -30,6 +34,7 @@ def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch ...@@ -30,6 +34,7 @@ def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch
try: try:
import apex import apex
meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment