Commit 079bf3cb (unverified), authored by Hongxin Liu, committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
from typing import Tuple from typing import Tuple
import torch import torch
from ..registry import meta_profiler_module from ..registry import meta_profiler_module
......
from functools import reduce
import operator import operator
from functools import reduce
from typing import Optional, Tuple
import torch import torch
from ..registry import meta_profiler_module from ..registry import meta_profiler_module
from typing import Optional, Tuple, Union
def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, def _rnn_flops(
w_hh: torch.Tensor) -> Tuple[int, int]: flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, w_hh: torch.Tensor
) -> Tuple[int, int]:
# copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py # copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py
# matrix matrix mult ih state and internal state # matrix matrix mult ih state and internal state
...@@ -42,12 +45,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch ...@@ -42,12 +45,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch
flops = 0 flops = 0
macs = 0 macs = 0
for i in range(self.num_layers): for i in range(self.num_layers):
w_ih = self.__getattr__('weight_ih_l' + str(i)) w_ih = self.__getattr__("weight_ih_l" + str(i))
w_hh = self.__getattr__('weight_hh_l' + str(i)) w_hh = self.__getattr__("weight_hh_l" + str(i))
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh) flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
if self.bias: if self.bias:
b_ih = self.__getattr__('bias_ih_l' + str(i)) b_ih = self.__getattr__("bias_ih_l" + str(i))
b_hh = self.__getattr__('bias_hh_l' + str(i)) b_hh = self.__getattr__("bias_hh_l" + str(i))
flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh) flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh)
flops *= reduce(operator.mul, input.shape[:2]) flops *= reduce(operator.mul, input.shape[:2])
macs *= reduce(operator.mul, input.shape[:2]) macs *= reduce(operator.mul, input.shape[:2])
...@@ -63,12 +66,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch ...@@ -63,12 +66,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch
def torch_nn_rnn(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]: def torch_nn_rnn(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
flops = 0 flops = 0
macs = 0 macs = 0
w_ih = self.__getattr__('weight_ih_l') w_ih = self.__getattr__("weight_ih_l")
w_hh = self.__getattr__('weight_hh_l') w_hh = self.__getattr__("weight_hh_l")
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh) flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
if self.bias: if self.bias:
b_ih = self.__getattr__('bias_ih_l') b_ih = self.__getattr__("bias_ih_l")
b_hh = self.__getattr__('bias_hh_l') b_hh = self.__getattr__("bias_hh_l")
flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh) flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh)
flops *= input.shape[0] flops *= input.shape[0]
macs *= input.shape[0] macs *= input.shape[0]
......
import operator from typing import Tuple
import torch import torch
from ..registry import meta_profiler_module from ..registry import meta_profiler_module
from typing import Optional, Tuple, Union
@meta_profiler_module.register(torch.nn.Flatten) @meta_profiler_module.register(torch.nn.Flatten)
......
class ProfilerRegistry: class ProfilerRegistry:
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
self.store = {} self.store = {}
def register(self, source): def register(self, source):
def wrapper(func): def wrapper(func):
self.store[source] = func self.store[source] = func
return func return func
...@@ -21,5 +19,5 @@ class ProfilerRegistry: ...@@ -21,5 +19,5 @@ class ProfilerRegistry:
return source in self.store return source in self.store
meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile') meta_profiler_function = ProfilerRegistry(name="patched_functions_for_meta_profile")
meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile') meta_profiler_module = ProfilerRegistry(name="patched_modules_for_meta_profile")
# for PyTorch 1.11 compatibility uses # for PyTorch 1.11 compatibility uses
from typing import Dict, List, Tuple, Union
import torch from torch.fx import Node
from torch.fx import GraphModule, Node
from ..._compatibility import compatibility from ..._compatibility import compatibility
...@@ -19,7 +17,7 @@ def calculate_fwd_in(n: Node) -> bool: ...@@ -19,7 +17,7 @@ def calculate_fwd_in(n: Node) -> bool:
Returns: Returns:
save_fwd_in (bool): the result of `save_fwd_in` save_fwd_in (bool): the result of `save_fwd_in`
""" """
return n.meta['save_fwd_in'] return n.meta["save_fwd_in"]
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
...@@ -45,4 +43,4 @@ def calculate_fwd_out(n: Node) -> int: ...@@ -45,4 +43,4 @@ def calculate_fwd_out(n: Node) -> int:
Returns: Returns:
fwd_out (int): the result of `fwd_out` fwd_out (int): the result of `fwd_out`
""" """
return n.meta['fwd_mem_out'] return n.meta["fwd_mem_out"]
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
import torch import torch
from torch.fx import GraphModule, Node from torch.fx import Node
from .._compatibility import compatibility, is_compatible_with_meta from .._compatibility import compatibility, is_compatible_with_meta
__all__ = ['activation_size', 'parameter_size', 'is_inplace'] __all__ = ["activation_size", "parameter_size", "is_inplace"]
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
...@@ -63,6 +63,7 @@ def is_inplace(n: Node): ...@@ -63,6 +63,7 @@ def is_inplace(n: Node):
inplace = n.kwargs.get("inplace", False) inplace = n.kwargs.get("inplace", False)
if is_compatible_with_meta(): if is_compatible_with_meta():
from .constants import ALIAS_ATEN from .constants import ALIAS_ATEN
if n.target in ALIAS_ATEN: if n.target in ALIAS_ATEN:
inplace = True inplace = True
elif n.op == "call_module": elif n.op == "call_module":
......
...@@ -173,8 +173,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable: ...@@ -173,8 +173,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
# Inputs[0] contains the shape of the input. # Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape input_shape = inputs[input_arg_index].shape
has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index], has_affine = (
'shape') else inputs[affine_arg_index] inputs[affine_arg_index].shape is not None
if hasattr(inputs[affine_arg_index], "shape")
else inputs[affine_arg_index]
)
assert 2 <= len(input_shape) <= 5, input_shape assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate # 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4) flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
...@@ -188,7 +191,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N ...@@ -188,7 +191,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N
training = inputs[-3] training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training: if training:
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape) input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1) return input_shape * (2 if has_affine else 1)
...@@ -218,15 +221,16 @@ def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> ...@@ -218,15 +221,16 @@ def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) ->
def zero_flop_jit(*args):
    """
    Count flops for zero flop layers.

    Any positional arguments are accepted and ignored; the reported
    flop count is always 0.

    Returns:
        int: always ``0``.
    """
    return 0
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse( if version.parse(torch.__version__) >= version.parse("1.12.0") and version.parse(torch.__version__) < version.parse(
torch.__version__) < version.parse('2.0.0'): "2.0.0"
):
flop_mapping = { flop_mapping = {
# gemm, gemv and dot # gemm, gemv and dot
aten.mm.default: matmul_flop_jit, aten.mm.default: matmul_flop_jit,
aten.mv.default: matmul_flop_jit, aten.mv.default: matmul_flop_jit,
aten.dot.default: matmul_flop_jit, aten.dot.default: matmul_flop_jit,
...@@ -234,13 +238,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse ...@@ -234,13 +238,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.addmm.default: addmm_flop_jit, aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit, aten.bmm.default: bmm_flop_jit,
aten.baddbmm.default: baddbmm_flop_jit, aten.baddbmm.default: baddbmm_flop_jit,
# convolution
# convolution
aten.convolution.default: conv_flop_jit, aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit, aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit, aten.convolution_backward.default: conv_backward_flop_jit,
# normalization
# normalization
aten.native_batch_norm.default: batchnorm_flop_jit, aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit, aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit, aten.cudnn_batch_norm.default: batchnorm_flop_jit,
...@@ -249,8 +251,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse ...@@ -249,8 +251,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
aten.native_group_norm.default: norm_flop_counter(2, 0), aten.native_group_norm.default: norm_flop_counter(2, 0),
aten.native_group_norm_backward.default: norm_flop_counter(2, 0), aten.native_group_norm_backward.default: norm_flop_counter(2, 0),
# pooling
# pooling
aten.avg_pool1d.default: elementwise_flop_counter(1, 0), aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d.default: elementwise_flop_counter(1, 0), aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1), aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
...@@ -275,7 +276,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse ...@@ -275,7 +276,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
} }
elementwise_flop_aten = [ elementwise_flop_aten = [
# basic op # basic op
aten.add.Tensor, aten.add.Tensor,
aten.add_.Tensor, aten.add_.Tensor,
aten.div.Tensor, aten.div.Tensor,
...@@ -296,8 +297,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse ...@@ -296,8 +297,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.exp.default, aten.exp.default,
aten.sin.default, aten.sin.default,
aten.cos.default, aten.cos.default,
# activation op
# activation op
aten.hardswish.default, aten.hardswish.default,
aten.hardswish_.default, aten.hardswish_.default,
aten.hardswish_backward.default, aten.hardswish_backward.default,
...@@ -320,8 +320,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse ...@@ -320,8 +320,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.tanh.default, aten.tanh.default,
aten.tanh_backward.default, aten.tanh_backward.default,
aten.threshold_backward.default, aten.threshold_backward.default,
# dropout
# dropout
aten.native_dropout.default, aten.native_dropout.default,
aten.native_dropout_backward.default, aten.native_dropout_backward.default,
] ]
...@@ -362,7 +361,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse ...@@ -362,7 +361,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.zero_.default, aten.zero_.default,
aten.zeros_like.default, aten.zeros_like.default,
aten.fill_.Scalar, aten.fill_.Scalar,
aten.stack.default aten.stack.default,
] # yapf: disable ] # yapf: disable
for op in zero_flop_aten: for op in zero_flop_aten:
......
...@@ -15,7 +15,7 @@ from .memory_utils import activation_size, parameter_size ...@@ -15,7 +15,7 @@ from .memory_utils import activation_size, parameter_size
from .opcount import flop_mapping from .opcount import flop_mapping
from .tensor import MetaTensor from .tensor import MetaTensor
__all__ = ['profile_function', 'profile_module', 'profile_method'] __all__ = ["profile_function", "profile_module", "profile_method"]
# super-dainiu: this cache should be global, otherwise it cannot # super-dainiu: this cache should be global, otherwise it cannot
# track duplicated tensors between nodes # track duplicated tensors between nodes
...@@ -174,7 +174,6 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G ...@@ -174,7 +174,6 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
# backward is executed. # backward is executed.
# Hopefully, this attempt will provide a better estimation of memory. # Hopefully, this attempt will provide a better estimation of memory.
class FlopTensor(MetaTensor): class FlopTensor(MetaTensor):
_node: Node = None _node: Node = None
def __repr__(self): def __repr__(self):
...@@ -186,24 +185,24 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G ...@@ -186,24 +185,24 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args) args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs) kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
node = subgraph.create_node('call_function', func, args_node, kwargs_node) node = subgraph.create_node("call_function", func, args_node, kwargs_node)
out = super().__torch_dispatch__(func, types, args, kwargs) out = super().__torch_dispatch__(func, types, args, kwargs)
flop_count[phase] += flop_mapping[func](args, normalize_tuple(out)) flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
node.meta['phase'] = phase node.meta["phase"] = phase
# super-dainiu: in `nn.MultiheadAttention` this weird thing occurs, # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
# i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
# `Phase.FORWARD` # `Phase.FORWARD`
if phase == Phase.FORWARD: if phase == Phase.FORWARD:
if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN: if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
node.meta['phase'] = Phase.PLACEHOLDER node.meta["phase"] = Phase.PLACEHOLDER
# TODO(yby): specify `saved_tensors` for backward memory estimation # TODO(yby): specify `saved_tensors` for backward memory estimation
node.meta['saved_tensor'] = [] node.meta["saved_tensor"] = []
if phase == Phase.BACKWARD: if phase == Phase.BACKWARD:
node.meta['saved_tensor'] = normalize_tuple(out) node.meta["saved_tensor"] = normalize_tuple(out)
def wrap(x): def wrap(x):
if isinstance(x, MetaTensor): if isinstance(x, MetaTensor):
...@@ -219,11 +218,14 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G ...@@ -219,11 +218,14 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
x = FlopTensor(x) x = FlopTensor(x)
if is_autogradable(x): if is_autogradable(x):
x.requires_grad_(True) x.requires_grad_(True)
x._node = subgraph.create_node('placeholder', x._node = subgraph.create_node(
'placeholder', (subgraph._root,), "placeholder",
name=subgraph._graph_namespace.create_name('input', x._tensor)) "placeholder",
x._node.meta['phase'] = Phase.PLACEHOLDER (subgraph._root,),
x._node.meta['saved_tensor'] = [] name=subgraph._graph_namespace.create_name("input", x._tensor),
)
x._node.meta["phase"] = Phase.PLACEHOLDER
x._node.meta["saved_tensor"] = []
return x return x
# Basically, we need to detach the args and kwargs from the outer graph. # Basically, we need to detach the args and kwargs from the outer graph.
...@@ -235,7 +237,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G ...@@ -235,7 +237,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache: if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
tensor = x._tensor.detach() tensor = x._tensor.detach()
tensor.data_ptr = x._tensor.data_ptr tensor.data_ptr = x._tensor.data_ptr
x._node.meta['saved_tensor'] += [tensor] x._node.meta["saved_tensor"] += [tensor]
if not do_not_cache: if not do_not_cache:
cache.add(x._tensor.data_ptr()) cache.add(x._tensor.data_ptr())
return x return x
...@@ -284,7 +286,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G ...@@ -284,7 +286,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def profile_function(target: 'Target', device: str = 'meta') -> Callable: def profile_function(target: "Target", device: str = "meta") -> Callable:
""" """
Wrap a `call_function` node or `torch.nn.functional` in order to Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution. record the memory cost and FLOPs of the execution.
...@@ -300,7 +302,6 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: ...@@ -300,7 +302,6 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
""" """
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# find the grad for parameter in args and kwargs # find the grad for parameter in args and kwargs
param_size = 0 param_size = 0
...@@ -316,18 +317,18 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: ...@@ -316,18 +317,18 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
# still run the profiling but discard some results regarding `target` # still run the profiling but discard some results regarding `target`
global do_not_cache global do_not_cache
inplace = kwargs.get('inplace', False) inplace = kwargs.get("inplace", False)
if target in OUTPUT_SAVED_OPS: if target in OUTPUT_SAVED_OPS:
do_not_cache = True do_not_cache = True
if inplace: if inplace:
do_not_cache = True do_not_cache = True
kwargs['inplace'] = False kwargs["inplace"] = False
if device == 'meta': if device == "meta":
out, meta = _profile_meta(func, *args, **kwargs) out, meta = _profile_meta(func, *args, **kwargs)
else: else:
out, meta = _profile_concrete(func, *args, **kwargs) out, meta = _profile_concrete(func, *args, **kwargs)
if inplace: if inplace:
kwargs['inplace'] = True kwargs["inplace"] = True
meta.bwd_mem_tmp = 0 meta.bwd_mem_tmp = 0
meta.bwd_mem_out = 0 meta.bwd_mem_out = 0
do_not_cache = False do_not_cache = False
...@@ -341,7 +342,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: ...@@ -341,7 +342,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def profile_method(target: 'Target', device: str = 'meta') -> Callable: def profile_method(target: "Target", device: str = "meta") -> Callable:
""" """
Wrap a `call_method` node Wrap a `call_method` node
record the memory cost and FLOPs of the execution. record the memory cost and FLOPs of the execution.
...@@ -349,8 +350,8 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable: ...@@ -349,8 +350,8 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# execute the method and return the result # execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.' assert isinstance(target, str), f"{target} instance is not str."
if device == 'meta': if device == "meta":
out, meta = _profile_meta(target, *args, **kwargs) out, meta = _profile_meta(target, *args, **kwargs)
else: else:
out, meta = _profile_concrete(target, *args, **kwargs) out, meta = _profile_concrete(target, *args, **kwargs)
...@@ -360,7 +361,7 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable: ...@@ -360,7 +361,7 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: def profile_module(module: torch.nn.Module, device: str = "meta") -> Callable:
""" """
Wrap a `call_module` node or `torch.nn` in order to Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution. record the memory cost and FLOPs of the execution.
...@@ -376,7 +377,6 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: ...@@ -376,7 +377,6 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
""" """
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# calculate parameter size # calculate parameter size
param_size = parameter_size(module) param_size = parameter_size(module)
...@@ -384,13 +384,13 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: ...@@ -384,13 +384,13 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
# still run the profiling but discard some results regarding `module`. # still run the profiling but discard some results regarding `module`.
global do_not_cache global do_not_cache
inplace = getattr(module, 'inplace', False) inplace = getattr(module, "inplace", False)
if type(module) in OUTPUT_SAVED_MOD: if type(module) in OUTPUT_SAVED_MOD:
do_not_cache = True do_not_cache = True
if inplace: if inplace:
do_not_cache = True do_not_cache = True
module.inplace = False module.inplace = False
if device == 'meta': if device == "meta":
out, meta = _profile_meta(func, *args, **kwargs) out, meta = _profile_meta(func, *args, **kwargs)
else: else:
out, meta = _profile_concrete(func, *args, **kwargs) out, meta = _profile_concrete(func, *args, **kwargs)
......
...@@ -59,9 +59,9 @@ def calculate_fwd_tmp(n: Node) -> int: ...@@ -59,9 +59,9 @@ def calculate_fwd_tmp(n: Node) -> int:
Returns: Returns:
bool: Whether the node is a ReLU-like node bool: Whether the node is a ReLU-like node
""" """
if n.op == 'call_function': if n.op == "call_function":
return n.target in OUTPUT_SAVED_OPS return n.target in OUTPUT_SAVED_OPS
elif n.op == 'call_module': elif n.op == "call_module":
return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
return False return False
......
import uuid import uuid
import torch import torch
from torch.types import _bool, _device, _dtype from torch.types import _device
from torch.utils._pytree import tree_flatten, tree_map from torch.utils._pytree import tree_map
from .._compatibility import compatibility from .._compatibility import compatibility
from .constants import ALIAS_ATEN from .constants import ALIAS_ATEN
__all__ = ['MetaTensor'] __all__ = ["MetaTensor"]
def set_data_ptr(x): def set_data_ptr(x):
...@@ -43,12 +43,13 @@ class MetaTensor(torch.Tensor): ...@@ -43,12 +43,13 @@ class MetaTensor(torch.Tensor):
storage_offset=elem.storage_offset(), storage_offset=elem.storage_offset(),
dtype=elem.dtype, dtype=elem.dtype,
layout=elem.layout, layout=elem.layout,
device=fake_device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')), device=fake_device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
requires_grad=elem.requires_grad) # deceive the frontend for aten selections requires_grad=elem.requires_grad,
) # deceive the frontend for aten selections
r._tensor = elem r._tensor = elem
# ...the real tensor is held as an element on the tensor. # ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta: if not r._tensor.is_meta:
r._tensor = r._tensor.to(torch.device('meta')) r._tensor = r._tensor.to(torch.device("meta"))
# only tensor not on `meta` should be copied to `meta` # only tensor not on `meta` should be copied to `meta`
set_data_ptr(r._tensor) set_data_ptr(r._tensor)
return r return r
...@@ -69,15 +70,15 @@ class MetaTensor(torch.Tensor): ...@@ -69,15 +70,15 @@ class MetaTensor(torch.Tensor):
x = x._tensor x = x._tensor
elif isinstance(x, torch.Tensor): elif isinstance(x, torch.Tensor):
fake_device = x.device fake_device = x.device
x = x.to(torch.device('meta')) x = x.to(torch.device("meta"))
return x return x
args = tree_map(unwrap, args) args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs) kwargs = tree_map(unwrap, kwargs)
if 'device' in kwargs: if "device" in kwargs:
fake_device = kwargs['device'] fake_device = kwargs["device"]
kwargs['device'] = torch.device('meta') kwargs["device"] = torch.device("meta")
# run aten for backend=CPU but actually on backend=Meta # run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs) out = func(*args, **kwargs)
...@@ -93,7 +94,7 @@ class MetaTensor(torch.Tensor): ...@@ -93,7 +94,7 @@ class MetaTensor(torch.Tensor):
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
nonlocal fake_device nonlocal fake_device
if not x.is_meta: if not x.is_meta:
x = x.to(torch.device('meta')) x = x.to(torch.device("meta"))
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
return tree_map(wrap, out) return tree_map(wrap, out)
...@@ -120,18 +121,18 @@ class MetaTensor(torch.Tensor): ...@@ -120,18 +121,18 @@ class MetaTensor(torch.Tensor):
nonlocal fake_device nonlocal fake_device
if isinstance(x, str) or isinstance(x, _device): if isinstance(x, str) or isinstance(x, _device):
fake_device = x fake_device = x
return 'meta' return "meta"
return x return x
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs)) elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, fake_device=fake_device) return MetaTensor(elem, fake_device=fake_device)
def cpu(self, *args, **kwargs):
    """
    Convert this tensor through ``self.to`` with CPU as the target device.

    All positional and keyword arguments are forwarded to ``self.to``
    unchanged. ``device="cpu"`` is added to the forwarded keyword arguments
    only when the tensor does not already report a ``cpu`` device.

    Returns:
        Whatever ``self.to`` returns for the forwarded arguments.
    """
    if self.device.type == "cpu":
        # Already reported as CPU: forward the call without forcing a device.
        return self.to(*args, **kwargs)
    return self.to(*args, device="cpu", **kwargs)
def cuda(self, device=None, non_blocking=False):
    """
    Convert this tensor through ``self.to`` with a CUDA target device.

    Args:
        device: target device. ``None`` selects ``"cuda:0"``. An integer
            ordinal ``N`` is normalised to the string ``"cuda:N"``. Any
            other value (e.g. a string or ``torch.device``) is forwarded
            as given.
        non_blocking: forwarded to ``self.to`` unchanged.

    Returns:
        Whatever ``self.to`` returns for the resolved device.
    """
    if device is None:
        device = "cuda:0"
    elif isinstance(device, int) and not isinstance(device, bool):
        # A bare ordinal is spelled out as a device string so that it is
        # handled like any other explicitly named device by ``self.to``.
        device = f"cuda:{device}"
    return self.to(device=device, non_blocking=non_blocking)
import operator from typing import Any
from typing import Any, List, Union
import torch import torch
from torch.fx.proxy import Attribute, Proxy from torch.fx.proxy import Proxy
from colossalai.fx.tracer.meta_patch import meta_patched_function from colossalai.fx.tracer.meta_patch import meta_patched_function
__all__ = ['ColoProxy'] __all__ = ["ColoProxy"]
class ColoProxy(Proxy): class ColoProxy(Proxy):
...@@ -39,11 +38,12 @@ class ColoProxy(Proxy): ...@@ -39,11 +38,12 @@ class ColoProxy(Proxy):
return self._meta_data is not None return self._meta_data is not None
def _assert_meta_data_is_tensor(self): def _assert_meta_data_is_tensor(self):
assert torch.is_tensor( assert (
self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}' torch.is_tensor(self._meta_data) and self._meta_data.is_meta
), f"Meta data is not a meta tensor for {self.node.name}"
def _assert_has_meta_data(self): def _assert_has_meta_data(self):
assert self._meta_data is not None, f'Meta data is not set for {self.node.name}' assert self._meta_data is not None, f"Meta data is not set for {self.node.name}"
def __len__(self): def __len__(self):
self._assert_has_meta_data() self._assert_has_meta_data()
...@@ -62,7 +62,6 @@ class ColoProxy(Proxy): ...@@ -62,7 +62,6 @@ class ColoProxy(Proxy):
return self.meta_data return self.meta_data
def __getattr__(self, k): def __getattr__(self, k):
return ColoAttribute(self, k) return ColoAttribute(self, k)
def __contains__(self, key): def __contains__(self, key):
...@@ -92,7 +91,6 @@ def extract_meta(*args, **kwargs): ...@@ -92,7 +91,6 @@ def extract_meta(*args, **kwargs):
class ColoAttribute(ColoProxy): class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str): def __init__(self, root, attr: str):
self.root = root self.root = root
self.attr = attr self.attr = attr
......
...@@ -39,7 +39,7 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr ...@@ -39,7 +39,7 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
_tensor: torch.Tensor _tensor: torch.Tensor
_node: Node _node: Node
__slots__ = ['_tensor', '_node'] __slots__ = ["_tensor", "_node"]
@staticmethod @staticmethod
def __new__(cls, tensor, fake_device=None, placeholder=False, name=None): def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
...@@ -51,22 +51,22 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr ...@@ -51,22 +51,22 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
dtype=tensor.dtype, dtype=tensor.dtype,
layout=tensor.layout, layout=tensor.layout,
device=fake_device if fake_device is not None else tensor.device, device=fake_device if fake_device is not None else tensor.device,
requires_grad=tensor.requires_grad) # deceive the frontend for aten selections requires_grad=tensor.requires_grad,
) # deceive the frontend for aten selections
r._tensor = tensor r._tensor = tensor
if placeholder: if placeholder:
if name is None: if name is None:
name = 'input' name = "input"
r._node = graph.create_node('placeholder', r._node = graph.create_node(
'placeholder', (graph._root,), "placeholder", "placeholder", (graph._root,), name=namespace.create_name(name, tensor)
name=namespace.create_name(name, tensor)) )
# ...the real tensor is held as an element on the tensor. # ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta: if not r._tensor.is_meta:
r._tensor = r._tensor.to(torch.device('meta')) r._tensor = r._tensor.to(torch.device("meta"))
return r return r
@classmethod @classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(x): def unwrap(x):
nonlocal fake_device nonlocal fake_device
if isinstance(x, MetaProxy): if isinstance(x, MetaProxy):
...@@ -75,21 +75,21 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr ...@@ -75,21 +75,21 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
# assert not isinstance(x, MetaProxy) # assert not isinstance(x, MetaProxy)
elif isinstance(x, torch.Tensor): elif isinstance(x, torch.Tensor):
fake_device = x.device fake_device = x.device
x = x.to(torch.device('meta')) x = x.to(torch.device("meta"))
return x return x
def get_node(x): def get_node(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_node'): if isinstance(x, torch.Tensor) and not hasattr(x, "_node"):
x = MetaProxy(x, placeholder=True, name='weight') x = MetaProxy(x, placeholder=True, name="weight")
return x if not hasattr(x, '_node') else x._node return x if not hasattr(x, "_node") else x._node
args_node = tree_map(get_node, args) args_node = tree_map(get_node, args)
kwargs_node = tree_map(get_node, kwargs) kwargs_node = tree_map(get_node, kwargs)
node = graph.create_node('call_function', func, args_node, kwargs_node) node = graph.create_node("call_function", func, args_node, kwargs_node)
if 'device' in kwargs: if "device" in kwargs:
fake_device = kwargs['device'] fake_device = kwargs["device"]
kwargs['device'] = torch.device('meta') kwargs["device"] = torch.device("meta")
args = tree_map(unwrap, args) args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs) kwargs = tree_map(unwrap, kwargs)
...@@ -103,9 +103,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr ...@@ -103,9 +103,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
nonlocal fake_device nonlocal fake_device
if not x.is_meta: if not x.is_meta:
x = x.to(torch.device('meta')) x = x.to(torch.device("meta"))
return MetaProxy( return (
x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x MetaProxy(x, fake_device=fake_device)
if isinstance(x, torch.Tensor) and not hasattr(x, "_tensor")
else x
)
def set_node(x): def set_node(x):
x._node = node x._node = node
...@@ -125,9 +128,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr ...@@ -125,9 +128,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
for tensor in normalize_tuple(out): for tensor in normalize_tuple(out):
if is_autogradable(tensor) and tensor.requires_grad: if is_autogradable(tensor) and tensor.requires_grad:
grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance( grad = (
tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta')) torch.empty_like(tensor._tensor, device=torch.device("meta"))
torch.autograd.backward(tensor, if isinstance(tensor, MetaProxy)
MetaProxy(grad, fake_device=tensor.device, placeholder=True), else torch.empty_like(tensor, device=torch.device("meta"))
retain_graph=True) )
torch.autograd.backward(
tensor, MetaProxy(grad, fake_device=tensor.device, placeholder=True), retain_graph=True
)
return graph return graph
...@@ -2,10 +2,10 @@ from typing import Any, List, Union ...@@ -2,10 +2,10 @@ from typing import Any, List, Union
import torch import torch
from ..proxy import ColoAttribute, ColoProxy from ..proxy import ColoProxy
from .meta_patch import meta_patched_function, meta_patched_module from .meta_patch import meta_patched_function
__all__ = ['is_element_in_list', 'extract_meta'] __all__ = ["is_element_in_list", "extract_meta"]
def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]): def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
...@@ -21,7 +21,6 @@ def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]): ...@@ -21,7 +21,6 @@ def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
def extract_meta(*args, **kwargs): def extract_meta(*args, **kwargs):
def _convert(val): def _convert(val):
if isinstance(val, ColoProxy): if isinstance(val, ColoProxy):
return val.meta_data return val.meta_data
......
import operator
import torch import torch
import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc from .bias_addition_function import LinearBasedBiasFunc
...@@ -10,13 +7,12 @@ from .bias_addition_function import LinearBasedBiasFunc ...@@ -10,13 +7,12 @@ from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addbmm) @bias_addition_method.register(torch.Tensor.addbmm)
@bias_addition_function.register(torch.addbmm) @bias_addition_function.register(torch.addbmm)
class Addbmm(LinearBasedBiasFunc): class Addbmm(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self): def extract_kwargs_from_origin_func(self):
kwargs = {} kwargs = {}
if 'beta' in self.kwargs: if "beta" in self.kwargs:
kwargs['beta'] = self.kwargs['beta'] kwargs["beta"] = self.kwargs["beta"]
if 'alpha' in self.kwargs: if "alpha" in self.kwargs:
kwargs['alpha'] = self.kwargs['alpha'] kwargs["alpha"] = self.kwargs["alpha"]
return kwargs return kwargs
def create_non_bias_func_proxy(self, input_proxy, other_proxy): def create_non_bias_func_proxy(self, input_proxy, other_proxy):
...@@ -25,7 +21,7 @@ class Addbmm(LinearBasedBiasFunc): ...@@ -25,7 +21,7 @@ class Addbmm(LinearBasedBiasFunc):
compute the main computation, such as convolution, with bias option banned. compute the main computation, such as convolution, with bias option banned.
""" """
assert self.substitute_func == torch.bmm assert self.substitute_func == torch.bmm
node_kind = 'call_function' node_kind = "call_function"
node_target = self.substitute_func node_target = self.substitute_func
node_args = (input_proxy, other_proxy) node_args = (input_proxy, other_proxy)
...@@ -35,10 +31,10 @@ class Addbmm(LinearBasedBiasFunc): ...@@ -35,10 +31,10 @@ class Addbmm(LinearBasedBiasFunc):
return non_bias_func_proxy return non_bias_func_proxy
def insert_sum_node(self, input_proxy, sum_dims=0): def insert_sum_node(self, input_proxy, sum_dims=0):
''' """
This method is used to sum the input_proxy through the sum_dims. This method is used to sum the input_proxy through the sum_dims.
''' """
node_kind = 'call_function' node_kind = "call_function"
node_target = torch.sum node_target = torch.sum
node_args = (input_proxy, sum_dims) node_args = (input_proxy, sum_dims)
node_kwargs = {} node_kwargs = {}
...@@ -55,15 +51,15 @@ class Addbmm(LinearBasedBiasFunc): ...@@ -55,15 +51,15 @@ class Addbmm(LinearBasedBiasFunc):
sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy) sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
kwargs = self.extract_kwargs_from_origin_func() kwargs = self.extract_kwargs_from_origin_func()
if 'beta' in kwargs: if "beta" in kwargs:
beta = kwargs['beta'] beta = kwargs["beta"]
# doing the multiplication with beta if it exists(temp_2 = beta * input) # doing the multiplication with beta if it exists(temp_2 = beta * input)
beta_proxy = self.create_mul_node(self.args[0], beta) beta_proxy = self.create_mul_node(self.args[0], beta)
else: else:
beta_proxy = self.args[0] beta_proxy = self.args[0]
if 'alpha' in kwargs: if "alpha" in kwargs:
alpha = kwargs['alpha'] alpha = kwargs["alpha"]
# doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1) # doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1)
alpha_proxy = self.create_mul_node(alpha, sum_proxy) alpha_proxy = self.create_mul_node(alpha, sum_proxy)
else: else:
......
import operator
import torch import torch
import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc from .bias_addition_function import LinearBasedBiasFunc
...@@ -10,17 +7,16 @@ from .bias_addition_function import LinearBasedBiasFunc ...@@ -10,17 +7,16 @@ from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addmm) @bias_addition_method.register(torch.Tensor.addmm)
@bias_addition_function.register(torch.addmm) @bias_addition_function.register(torch.addmm)
class Addmm(LinearBasedBiasFunc): class Addmm(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self): def extract_kwargs_from_origin_func(self):
kwargs = {} kwargs = {}
if 'beta' in self.kwargs: if "beta" in self.kwargs:
kwargs['beta'] = self.kwargs['beta'] kwargs["beta"] = self.kwargs["beta"]
if 'alpha' in self.kwargs: if "alpha" in self.kwargs:
kwargs['alpha'] = self.kwargs['alpha'] kwargs["alpha"] = self.kwargs["alpha"]
return kwargs return kwargs
def transpose_other_operand_for_linear(self, other_proxy): def transpose_other_operand_for_linear(self, other_proxy):
''' """
This method is used to transpose the other operand for linear function. This method is used to transpose the other operand for linear function.
For example: For example:
input = torch.rand(3, 4) input = torch.rand(3, 4)
...@@ -30,8 +26,8 @@ class Addmm(LinearBasedBiasFunc): ...@@ -30,8 +26,8 @@ class Addmm(LinearBasedBiasFunc):
# To keep the computation graph consistent with the origin computation graph, we need to transpose the m2 # To keep the computation graph consistent with the origin computation graph, we need to transpose the m2
# before we call the linear function. # before we call the linear function.
new_output = torch.linear(m1, m2.transpose(0, 1)) + input new_output = torch.linear(m1, m2.transpose(0, 1)) + input
''' """
node_kind = 'call_function' node_kind = "call_function"
node_target = torch.transpose node_target = torch.transpose
node_args = (other_proxy, 0, 1) node_args = (other_proxy, 0, 1)
node_kwargs = {} node_kwargs = {}
...@@ -43,14 +39,14 @@ class Addmm(LinearBasedBiasFunc): ...@@ -43,14 +39,14 @@ class Addmm(LinearBasedBiasFunc):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy) non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy)
kwargs = self.extract_kwargs_from_origin_func() kwargs = self.extract_kwargs_from_origin_func()
if 'beta' in kwargs: if "beta" in kwargs:
beta = kwargs['beta'] beta = kwargs["beta"]
beta_proxy = self.create_mul_node(self.args[0], beta) beta_proxy = self.create_mul_node(self.args[0], beta)
else: else:
beta_proxy = self.args[0] beta_proxy = self.args[0]
if 'alpha' in kwargs: if "alpha" in kwargs:
alpha = kwargs['alpha'] alpha = kwargs["alpha"]
alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy) alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
else: else:
alpha_proxy = non_bias_linear_func_proxy alpha_proxy = non_bias_linear_func_proxy
......
...@@ -29,7 +29,6 @@ class BiasAdditionFunc(ABC): ...@@ -29,7 +29,6 @@ class BiasAdditionFunc(ABC):
to insert two more operator.mul nodes for the computation graph to compute the to insert two more operator.mul nodes for the computation graph to compute the
final result. final result.
""" """
pass
@abstractmethod @abstractmethod
def generate(self): def generate(self):
...@@ -50,7 +49,6 @@ class BiasAdditionFunc(ABC): ...@@ -50,7 +49,6 @@ class BiasAdditionFunc(ABC):
%mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {}) %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {}) %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
""" """
pass
def create_mul_node(self, input_proxy, coefficent): def create_mul_node(self, input_proxy, coefficent):
""" """
...@@ -59,7 +57,7 @@ class BiasAdditionFunc(ABC): ...@@ -59,7 +57,7 @@ class BiasAdditionFunc(ABC):
Therefore, we need to use this method insert two more operator.mul nodes for Therefore, we need to use this method insert two more operator.mul nodes for
the computation graph to compute the final result. the computation graph to compute the final result.
""" """
node_kind = 'call_function' node_kind = "call_function"
node_target = operator.mul node_target = operator.mul
node_args = ( node_args = (
input_proxy, input_proxy,
...@@ -82,7 +80,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc): ...@@ -82,7 +80,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
compute the main computation, such as convolution, with bias option banned. compute the main computation, such as convolution, with bias option banned.
""" """
assert self.substitute_func == torch.nn.functional.linear assert self.substitute_func == torch.nn.functional.linear
node_kind = 'call_function' node_kind = "call_function"
node_target = self.substitute_func node_target = self.substitute_func
node_args = (input_proxy, other_proxy) node_args = (input_proxy, other_proxy)
...@@ -96,7 +94,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc): ...@@ -96,7 +94,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
This method is used to create the bias_addition_proxy, the node created by this proxy will This method is used to create the bias_addition_proxy, the node created by this proxy will
compute the sum of non_bias_func result and bias with some reshape operation if needed. compute the sum of non_bias_func result and bias with some reshape operation if needed.
""" """
bias_add_node_kind = 'call_function' bias_add_node_kind = "call_function"
bias_add_node_target = operator.add bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy) bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
......
import operator
import torch
import torch.nn.functional as F import torch.nn.functional as F
from ...registry import bias_addition_function from ...registry import bias_addition_function
...@@ -9,17 +6,16 @@ from .bias_addition_function import LinearBasedBiasFunc ...@@ -9,17 +6,16 @@ from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_function.register(F.linear) @bias_addition_function.register(F.linear)
class Linear(LinearBasedBiasFunc): class Linear(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self): def extract_kwargs_from_origin_func(self):
assert 'bias' in self.kwargs assert "bias" in self.kwargs
kwargs = {} kwargs = {}
if 'bias' in self.kwargs: if "bias" in self.kwargs:
kwargs['bias'] = self.kwargs['bias'] kwargs["bias"] = self.kwargs["bias"]
return kwargs return kwargs
def generate(self): def generate(self):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1]) non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1])
kwargs = self.extract_kwargs_from_origin_func() kwargs = self.extract_kwargs_from_origin_func()
bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias']) bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs["bias"])
return bias_addition_proxy return bias_addition_proxy
...@@ -27,8 +27,8 @@ class BiasAdditionModule(ABC): ...@@ -27,8 +27,8 @@ class BiasAdditionModule(ABC):
Note: this function will be invoked during module initializing, Note: this function will be invoked during module initializing,
you should never call this function. you should never call this function.
""" """
weight_node_kind = 'get_attr' weight_node_kind = "get_attr"
weight_node_target = self.target + '.weight' weight_node_target = self.target + ".weight"
weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {}) weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {})
return weight_proxy return weight_proxy
...@@ -39,8 +39,8 @@ class BiasAdditionModule(ABC): ...@@ -39,8 +39,8 @@ class BiasAdditionModule(ABC):
Note: this function will be invoked during module initializing, Note: this function will be invoked during module initializing,
you should never call this function. you should never call this function.
""" """
bias_node_kind = 'get_attr' bias_node_kind = "get_attr"
bias_node_target = self.target + '.bias' bias_node_target = self.target + ".bias"
bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {}) bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {})
return bias_proxy return bias_proxy
...@@ -54,14 +54,13 @@ class BiasAdditionModule(ABC): ...@@ -54,14 +54,13 @@ class BiasAdditionModule(ABC):
considered during module initializing. However, we need to consider those attributes as kwargs considered during module initializing. However, we need to consider those attributes as kwargs
in F.conv2d. in F.conv2d.
""" """
pass
def create_non_bias_func_proxy(self, input_proxy=None): def create_non_bias_func_proxy(self, input_proxy=None):
""" """
This method is used to create the non_bias_func proxy, the node created by this proxy will This method is used to create the non_bias_func proxy, the node created by this proxy will
compute the main computation, such as convolution, with bias option banned. compute the main computation, such as convolution, with bias option banned.
""" """
node_kind = 'call_function' node_kind = "call_function"
node_target = self.substitute_func node_target = self.substitute_func
if input_proxy is None: if input_proxy is None:
input_proxy = self.args[0] input_proxy = self.args[0]
...@@ -75,7 +74,7 @@ class BiasAdditionModule(ABC): ...@@ -75,7 +74,7 @@ class BiasAdditionModule(ABC):
This method is used to create the bias_addition_proxy, the node created by this proxy will This method is used to create the bias_addition_proxy, the node created by this proxy will
compute the sum of non_bias_func result and bias with some reshape operation if needed. compute the sum of non_bias_func result and bias with some reshape operation if needed.
""" """
bias_add_node_kind = 'call_function' bias_add_node_kind = "call_function"
bias_add_node_target = operator.add bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy) bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
...@@ -100,7 +99,6 @@ class BiasAdditionModule(ABC): ...@@ -100,7 +99,6 @@ class BiasAdditionModule(ABC):
%view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
""" """
pass
module_to_func_dict = { module_to_func_dict = {
......
import torch import torch
import torch.nn.functional as F from torch.nn.modules.utils import _pair, _single, _triple
from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple
from ...registry import bias_addition_module from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule from .bias_addition_module import BiasAdditionModule
...@@ -10,17 +9,16 @@ from .bias_addition_module import BiasAdditionModule ...@@ -10,17 +9,16 @@ from .bias_addition_module import BiasAdditionModule
@bias_addition_module.register(torch.nn.Conv2d) @bias_addition_module.register(torch.nn.Conv2d)
@bias_addition_module.register(torch.nn.Conv3d) @bias_addition_module.register(torch.nn.Conv3d)
class BiasAdditionConv(BiasAdditionModule): class BiasAdditionConv(BiasAdditionModule):
def extract_kwargs_from_mod(self): def extract_kwargs_from_mod(self):
root = self.tracer.root root = self.tracer.root
conv_module = root.get_submodule(self.target) conv_module = root.get_submodule(self.target)
kwarg_attributes = ['groups', 'dilation', 'stride'] kwarg_attributes = ["groups", "dilation", "stride"]
non_bias_kwargs = {} non_bias_kwargs = {}
for attr_name in kwarg_attributes: for attr_name in kwarg_attributes:
if hasattr(conv_module, attr_name): if hasattr(conv_module, attr_name):
non_bias_kwargs[attr_name] = getattr(conv_module, attr_name) non_bias_kwargs[attr_name] = getattr(conv_module, attr_name)
if conv_module.padding_mode != "zeros": if conv_module.padding_mode != "zeros":
#TODO: non zeros mode requires some extra processing for input # TODO: non zeros mode requires some extra processing for input
conv_type = type(conv_module) conv_type = type(conv_module)
if conv_type == "torch.nn.Conv1d": if conv_type == "torch.nn.Conv1d":
padding_element = _single(0) padding_element = _single(0)
...@@ -28,9 +26,9 @@ class BiasAdditionConv(BiasAdditionModule): ...@@ -28,9 +26,9 @@ class BiasAdditionConv(BiasAdditionModule):
padding_element = _pair(0) padding_element = _pair(0)
elif conv_type == "torch.nn.Conv3d": elif conv_type == "torch.nn.Conv3d":
padding_element = _triple(0) padding_element = _triple(0)
non_bias_kwargs['padding'] = padding_element non_bias_kwargs["padding"] = padding_element
else: else:
non_bias_kwargs['padding'] = getattr(conv_module, 'padding') non_bias_kwargs["padding"] = getattr(conv_module, "padding")
return non_bias_kwargs return non_bias_kwargs
...@@ -41,11 +39,12 @@ class BiasAdditionConv(BiasAdditionModule): ...@@ -41,11 +39,12 @@ class BiasAdditionConv(BiasAdditionModule):
""" """
bias_shape = [1] * (dimensions - 1) bias_shape = [1] * (dimensions - 1)
bias_shape[0] = -1 bias_shape[0] = -1
bias_reshape_node_kind = 'call_method' bias_reshape_node_kind = "call_method"
bias_reshape_node_target = 'view' bias_reshape_node_target = "view"
bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape)) bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target, bias_reshape_proxy = self.tracer.create_proxy(
bias_reshape_node_args, {}) bias_reshape_node_kind, bias_reshape_node_target, bias_reshape_node_args, {}
)
return bias_reshape_proxy return bias_reshape_proxy
def generate(self): def generate(self):
......
import torch import torch
import torch.nn.functional as F
from ...registry import bias_addition_module from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule from .bias_addition_module import BiasAdditionModule
...@@ -7,7 +6,6 @@ from .bias_addition_module import BiasAdditionModule ...@@ -7,7 +6,6 @@ from .bias_addition_module import BiasAdditionModule
@bias_addition_module.register(torch.nn.Linear) @bias_addition_module.register(torch.nn.Linear)
class BiasAdditionLinear(BiasAdditionModule): class BiasAdditionLinear(BiasAdditionModule):
def extract_kwargs_from_mod(self): def extract_kwargs_from_mod(self):
return {} return {}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment