Unverified commit 079bf3cb, authored by Hongxin Liu and committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
from typing import Tuple
import torch
from ..registry import meta_profiler_module
......
from functools import reduce
import operator
from functools import reduce
from typing import Optional, Tuple
import torch
from ..registry import meta_profiler_module
from typing import Optional, Tuple, Union
def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor,
w_hh: torch.Tensor) -> Tuple[int, int]:
def _rnn_flops(
flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, w_hh: torch.Tensor
) -> Tuple[int, int]:
# copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py
# matrix matrix mult ih state and internal state
......@@ -42,12 +45,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch
flops = 0
macs = 0
for i in range(self.num_layers):
w_ih = self.__getattr__('weight_ih_l' + str(i))
w_hh = self.__getattr__('weight_hh_l' + str(i))
w_ih = self.__getattr__("weight_ih_l" + str(i))
w_hh = self.__getattr__("weight_hh_l" + str(i))
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
if self.bias:
b_ih = self.__getattr__('bias_ih_l' + str(i))
b_hh = self.__getattr__('bias_hh_l' + str(i))
b_ih = self.__getattr__("bias_ih_l" + str(i))
b_hh = self.__getattr__("bias_hh_l" + str(i))
flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh)
flops *= reduce(operator.mul, input.shape[:2])
macs *= reduce(operator.mul, input.shape[:2])
......@@ -63,12 +66,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch
def torch_nn_rnn(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
flops = 0
macs = 0
w_ih = self.__getattr__('weight_ih_l')
w_hh = self.__getattr__('weight_hh_l')
w_ih = self.__getattr__("weight_ih_l")
w_hh = self.__getattr__("weight_hh_l")
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
if self.bias:
b_ih = self.__getattr__('bias_ih_l')
b_hh = self.__getattr__('bias_hh_l')
b_ih = self.__getattr__("bias_ih_l")
b_hh = self.__getattr__("bias_hh_l")
flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh)
flops *= input.shape[0]
macs *= input.shape[0]
......
import operator
from typing import Tuple
import torch
from ..registry import meta_profiler_module
from typing import Optional, Tuple, Union
@meta_profiler_module.register(torch.nn.Flatten)
......
class ProfilerRegistry:
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
def wrapper(func):
self.store[source] = func
return func
......@@ -21,5 +19,5 @@ class ProfilerRegistry:
return source in self.store
meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile')
meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile')
meta_profiler_function = ProfilerRegistry(name="patched_functions_for_meta_profile")
meta_profiler_module = ProfilerRegistry(name="patched_modules_for_meta_profile")
# for PyTorch 1.11 compatibility uses
from typing import Dict, List, Tuple, Union
import torch
from torch.fx import GraphModule, Node
from torch.fx import Node
from ..._compatibility import compatibility
......@@ -19,7 +17,7 @@ def calculate_fwd_in(n: Node) -> bool:
Returns:
save_fwd_in (bool): the result of `save_fwd_in`
"""
return n.meta['save_fwd_in']
return n.meta["save_fwd_in"]
@compatibility(is_backward_compatible=True)
......@@ -45,4 +43,4 @@ def calculate_fwd_out(n: Node) -> int:
Returns:
fwd_out (int): the result of `fwd_out`
"""
return n.meta['fwd_mem_out']
return n.meta["fwd_mem_out"]
from typing import Dict, List, Tuple, Union
import torch
from torch.fx import GraphModule, Node
from torch.fx import Node
from .._compatibility import compatibility, is_compatible_with_meta
__all__ = ['activation_size', 'parameter_size', 'is_inplace']
__all__ = ["activation_size", "parameter_size", "is_inplace"]
@compatibility(is_backward_compatible=True)
......@@ -63,6 +63,7 @@ def is_inplace(n: Node):
inplace = n.kwargs.get("inplace", False)
if is_compatible_with_meta():
from .constants import ALIAS_ATEN
if n.target in ALIAS_ATEN:
inplace = True
elif n.op == "call_module":
......
......@@ -173,8 +173,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
# Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape
has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
'shape') else inputs[affine_arg_index]
has_affine = (
inputs[affine_arg_index].shape is not None
if hasattr(inputs[affine_arg_index], "shape")
else inputs[affine_arg_index]
)
assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
......@@ -188,7 +191,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N
training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training:
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1)
......@@ -218,15 +221,16 @@ def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) ->
def zero_flop_jit(*args):
    """
    Count flops for zero flop layers.

    Args:
        *args: accepted so the function can be called like the other flop
            counters; every argument is ignored.

    Returns:
        int: always 0.
    """
    return 0
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
if version.parse(torch.__version__) >= version.parse("1.12.0") and version.parse(torch.__version__) < version.parse(
"2.0.0"
):
flop_mapping = {
# gemm, gemv and dot
# gemm, gemv and dot
aten.mm.default: matmul_flop_jit,
aten.mv.default: matmul_flop_jit,
aten.dot.default: matmul_flop_jit,
......@@ -234,13 +238,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
aten.baddbmm.default: baddbmm_flop_jit,
# convolution
# convolution
aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit,
# normalization
# normalization
aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
......@@ -249,8 +251,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
aten.native_group_norm.default: norm_flop_counter(2, 0),
aten.native_group_norm_backward.default: norm_flop_counter(2, 0),
# pooling
# pooling
aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
......@@ -275,7 +276,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
}
elementwise_flop_aten = [
# basic op
# basic op
aten.add.Tensor,
aten.add_.Tensor,
aten.div.Tensor,
......@@ -296,8 +297,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.exp.default,
aten.sin.default,
aten.cos.default,
# activation op
# activation op
aten.hardswish.default,
aten.hardswish_.default,
aten.hardswish_backward.default,
......@@ -320,8 +320,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.tanh.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
# dropout
# dropout
aten.native_dropout.default,
aten.native_dropout_backward.default,
]
......@@ -362,7 +361,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.zero_.default,
aten.zeros_like.default,
aten.fill_.Scalar,
aten.stack.default
aten.stack.default,
] # yapf: disable
for op in zero_flop_aten:
......
......@@ -15,7 +15,7 @@ from .memory_utils import activation_size, parameter_size
from .opcount import flop_mapping
from .tensor import MetaTensor
__all__ = ['profile_function', 'profile_module', 'profile_method']
__all__ = ["profile_function", "profile_module", "profile_method"]
# super-dainiu: this cache should be global, otherwise it cannot
# track duplicated tensors between nodes
......@@ -174,7 +174,6 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
# backward is executed.
# Hopefully, this attempt will provide a better estimation of memory.
class FlopTensor(MetaTensor):
_node: Node = None
def __repr__(self):
......@@ -186,24 +185,24 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
node = subgraph.create_node('call_function', func, args_node, kwargs_node)
node = subgraph.create_node("call_function", func, args_node, kwargs_node)
out = super().__torch_dispatch__(func, types, args, kwargs)
flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
node.meta['phase'] = phase
node.meta["phase"] = phase
# super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
# i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
# `Phase.FORWARD`
if phase == Phase.FORWARD:
if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
node.meta['phase'] = Phase.PLACEHOLDER
node.meta["phase"] = Phase.PLACEHOLDER
# TODO(yby): specify `saved_tensors` for backward memory estimation
node.meta['saved_tensor'] = []
node.meta["saved_tensor"] = []
if phase == Phase.BACKWARD:
node.meta['saved_tensor'] = normalize_tuple(out)
node.meta["saved_tensor"] = normalize_tuple(out)
def wrap(x):
if isinstance(x, MetaTensor):
......@@ -219,11 +218,14 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
x = FlopTensor(x)
if is_autogradable(x):
x.requires_grad_(True)
x._node = subgraph.create_node('placeholder',
'placeholder', (subgraph._root,),
name=subgraph._graph_namespace.create_name('input', x._tensor))
x._node.meta['phase'] = Phase.PLACEHOLDER
x._node.meta['saved_tensor'] = []
x._node = subgraph.create_node(
"placeholder",
"placeholder",
(subgraph._root,),
name=subgraph._graph_namespace.create_name("input", x._tensor),
)
x._node.meta["phase"] = Phase.PLACEHOLDER
x._node.meta["saved_tensor"] = []
return x
# Basically, we need to detach the args and kwargs from the outer graph.
......@@ -235,7 +237,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
tensor = x._tensor.detach()
tensor.data_ptr = x._tensor.data_ptr
x._node.meta['saved_tensor'] += [tensor]
x._node.meta["saved_tensor"] += [tensor]
if not do_not_cache:
cache.add(x._tensor.data_ptr())
return x
......@@ -284,7 +286,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target', device: str = 'meta') -> Callable:
def profile_function(target: "Target", device: str = "meta") -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
......@@ -300,7 +302,6 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# find the grad for parameter in args and kwargs
param_size = 0
......@@ -316,18 +317,18 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
# still run the profiling but discard some results regarding `target`
global do_not_cache
inplace = kwargs.get('inplace', False)
inplace = kwargs.get("inplace", False)
if target in OUTPUT_SAVED_OPS:
do_not_cache = True
if inplace:
do_not_cache = True
kwargs['inplace'] = False
if device == 'meta':
kwargs["inplace"] = False
if device == "meta":
out, meta = _profile_meta(func, *args, **kwargs)
else:
out, meta = _profile_concrete(func, *args, **kwargs)
if inplace:
kwargs['inplace'] = True
kwargs["inplace"] = True
meta.bwd_mem_tmp = 0
meta.bwd_mem_out = 0
do_not_cache = False
......@@ -341,7 +342,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
@compatibility(is_backward_compatible=True)
def profile_method(target: 'Target', device: str = 'meta') -> Callable:
def profile_method(target: "Target", device: str = "meta") -> Callable:
"""
Wrap a `call_method` node
record the memory cost and FLOPs of the execution.
......@@ -349,8 +350,8 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
if device == 'meta':
assert isinstance(target, str), f"{target} instance is not str."
if device == "meta":
out, meta = _profile_meta(target, *args, **kwargs)
else:
out, meta = _profile_concrete(target, *args, **kwargs)
......@@ -360,7 +361,7 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
def profile_module(module: torch.nn.Module, device: str = "meta") -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
......@@ -376,7 +377,6 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# calculate parameter size
param_size = parameter_size(module)
......@@ -384,13 +384,13 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
# still run the profiling but discard some results regarding `module`.
global do_not_cache
inplace = getattr(module, 'inplace', False)
inplace = getattr(module, "inplace", False)
if type(module) in OUTPUT_SAVED_MOD:
do_not_cache = True
if inplace:
do_not_cache = True
module.inplace = False
if device == 'meta':
if device == "meta":
out, meta = _profile_meta(func, *args, **kwargs)
else:
out, meta = _profile_concrete(func, *args, **kwargs)
......
......@@ -59,9 +59,9 @@ def calculate_fwd_tmp(n: Node) -> int:
Returns:
bool: Whether the node is a ReLU-like node
"""
if n.op == 'call_function':
if n.op == "call_function":
return n.target in OUTPUT_SAVED_OPS
elif n.op == 'call_module':
elif n.op == "call_module":
return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
return False
......
import uuid
import torch
from torch.types import _bool, _device, _dtype
from torch.utils._pytree import tree_flatten, tree_map
from torch.types import _device
from torch.utils._pytree import tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN
__all__ = ['MetaTensor']
__all__ = ["MetaTensor"]
def set_data_ptr(x):
......@@ -43,12 +43,13 @@ class MetaTensor(torch.Tensor):
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=fake_device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
requires_grad=elem.requires_grad) # deceive the frontend for aten selections
device=fake_device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
requires_grad=elem.requires_grad,
) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
r._tensor = r._tensor.to(torch.device('meta'))
r._tensor = r._tensor.to(torch.device("meta"))
# only tensor not on `meta` should be copied to `meta`
set_data_ptr(r._tensor)
return r
......@@ -69,15 +70,15 @@ class MetaTensor(torch.Tensor):
x = x._tensor
elif isinstance(x, torch.Tensor):
fake_device = x.device
x = x.to(torch.device('meta'))
x = x.to(torch.device("meta"))
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
if 'device' in kwargs:
fake_device = kwargs['device']
kwargs['device'] = torch.device('meta')
if "device" in kwargs:
fake_device = kwargs["device"]
kwargs["device"] = torch.device("meta")
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
......@@ -93,7 +94,7 @@ class MetaTensor(torch.Tensor):
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
x = x.to(torch.device('meta'))
x = x.to(torch.device("meta"))
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
return tree_map(wrap, out)
......@@ -120,18 +121,18 @@ class MetaTensor(torch.Tensor):
nonlocal fake_device
if isinstance(x, str) or isinstance(x, _device):
fake_device = x
return 'meta'
return "meta"
return x
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, fake_device=fake_device)
def cpu(self, *args, **kwargs):
    """
    Return the result of `self.to(...)` targeting the CPU.

    When `self.device.type` is already "cpu", the call is forwarded to
    `self.to` unchanged; otherwise `device="cpu"` is added to the keyword
    arguments.

    Args:
        *args: positional arguments forwarded to `self.to`.
        **kwargs: keyword arguments forwarded to `self.to`.

    Returns:
        Whatever `self.to` returns.
    """
    if self.device.type == "cpu":
        # Already reported as CPU: do not inject a `device` keyword.
        return self.to(*args, **kwargs)
    return self.to(*args, device="cpu", **kwargs)
def cuda(self, device=None, non_blocking=False):
    """
    Return the result of `self.to(...)` targeting a CUDA device.

    Args:
        device: target device; when None, "cuda:0" is used.
        non_blocking: forwarded to `self.to` as-is.

    Returns:
        Whatever `self.to` returns.
    """
    if device is not None:
        return self.to(device=device, non_blocking=non_blocking)
    # No explicit device requested: fall back to the first CUDA device.
    return self.to(device="cuda:0", non_blocking=non_blocking)
import operator
from typing import Any, List, Union
from typing import Any
import torch
from torch.fx.proxy import Attribute, Proxy
from torch.fx.proxy import Proxy
from colossalai.fx.tracer.meta_patch import meta_patched_function
__all__ = ['ColoProxy']
__all__ = ["ColoProxy"]
class ColoProxy(Proxy):
......@@ -39,11 +38,12 @@ class ColoProxy(Proxy):
return self._meta_data is not None
def _assert_meta_data_is_tensor(self):
    """
    Assert that `self._meta_data` is a tensor and that it is a meta tensor.

    Raises:
        AssertionError: if `self._meta_data` is not a `torch.Tensor` or its
            `is_meta` flag is false; the message names `self.node.name`.
    """
    assert (
        torch.is_tensor(self._meta_data) and self._meta_data.is_meta
    ), f"Meta data is not a meta tensor for {self.node.name}"
def _assert_has_meta_data(self):
    """
    Assert that `self._meta_data` has been set (is not None).

    Raises:
        AssertionError: if `self._meta_data` is None; the message names
            `self.node.name`.
    """
    assert self._meta_data is not None, f"Meta data is not set for {self.node.name}"
def __len__(self):
self._assert_has_meta_data()
......@@ -62,7 +62,6 @@ class ColoProxy(Proxy):
return self.meta_data
def __getattr__(self, k):
return ColoAttribute(self, k)
def __contains__(self, key):
......@@ -92,7 +91,6 @@ def extract_meta(*args, **kwargs):
class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str):
self.root = root
self.attr = attr
......
......@@ -39,7 +39,7 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
_tensor: torch.Tensor
_node: Node
__slots__ = ['_tensor', '_node']
__slots__ = ["_tensor", "_node"]
@staticmethod
def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
......@@ -51,22 +51,22 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
dtype=tensor.dtype,
layout=tensor.layout,
device=fake_device if fake_device is not None else tensor.device,
requires_grad=tensor.requires_grad) # deceive the frontend for aten selections
requires_grad=tensor.requires_grad,
) # deceive the frontend for aten selections
r._tensor = tensor
if placeholder:
if name is None:
name = 'input'
r._node = graph.create_node('placeholder',
'placeholder', (graph._root,),
name=namespace.create_name(name, tensor))
name = "input"
r._node = graph.create_node(
"placeholder", "placeholder", (graph._root,), name=namespace.create_name(name, tensor)
)
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
r._tensor = r._tensor.to(torch.device('meta'))
r._tensor = r._tensor.to(torch.device("meta"))
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(x):
nonlocal fake_device
if isinstance(x, MetaProxy):
......@@ -75,21 +75,21 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
# assert not isinstance(x, MetaProxy)
elif isinstance(x, torch.Tensor):
fake_device = x.device
x = x.to(torch.device('meta'))
x = x.to(torch.device("meta"))
return x
def get_node(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
x = MetaProxy(x, placeholder=True, name='weight')
return x if not hasattr(x, '_node') else x._node
if isinstance(x, torch.Tensor) and not hasattr(x, "_node"):
x = MetaProxy(x, placeholder=True, name="weight")
return x if not hasattr(x, "_node") else x._node
args_node = tree_map(get_node, args)
kwargs_node = tree_map(get_node, kwargs)
node = graph.create_node('call_function', func, args_node, kwargs_node)
node = graph.create_node("call_function", func, args_node, kwargs_node)
if 'device' in kwargs:
fake_device = kwargs['device']
kwargs['device'] = torch.device('meta')
if "device" in kwargs:
fake_device = kwargs["device"]
kwargs["device"] = torch.device("meta")
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
......@@ -103,9 +103,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
x = x.to(torch.device('meta'))
return MetaProxy(
x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
x = x.to(torch.device("meta"))
return (
MetaProxy(x, fake_device=fake_device)
if isinstance(x, torch.Tensor) and not hasattr(x, "_tensor")
else x
)
def set_node(x):
x._node = node
......@@ -125,9 +128,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
for tensor in normalize_tuple(out):
if is_autogradable(tensor) and tensor.requires_grad:
grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
torch.autograd.backward(tensor,
MetaProxy(grad, fake_device=tensor.device, placeholder=True),
retain_graph=True)
grad = (
torch.empty_like(tensor._tensor, device=torch.device("meta"))
if isinstance(tensor, MetaProxy)
else torch.empty_like(tensor, device=torch.device("meta"))
)
torch.autograd.backward(
tensor, MetaProxy(grad, fake_device=tensor.device, placeholder=True), retain_graph=True
)
return graph
......@@ -2,10 +2,10 @@ from typing import Any, List, Union
import torch
from ..proxy import ColoAttribute, ColoProxy
from .meta_patch import meta_patched_function, meta_patched_module
from ..proxy import ColoProxy
from .meta_patch import meta_patched_function
__all__ = ['is_element_in_list', 'extract_meta']
__all__ = ["is_element_in_list", "extract_meta"]
def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
......@@ -21,7 +21,6 @@ def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
def extract_meta(*args, **kwargs):
def _convert(val):
if isinstance(val, ColoProxy):
return val.meta_data
......
import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
......@@ -10,13 +7,12 @@ from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addbmm)
@bias_addition_function.register(torch.addbmm)
class Addbmm(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self):
    """
    Collect the scaling keyword arguments of the original call.

    Only the `beta` and `alpha` entries of `self.kwargs` are copied; keys
    that are absent are simply omitted from the result.

    Returns:
        dict: a new dict holding `beta` and/or `alpha` when present.
    """
    # Order (beta, then alpha) matches the original insertion order.
    return {key: self.kwargs[key] for key in ("beta", "alpha") if key in self.kwargs}
def create_non_bias_func_proxy(self, input_proxy, other_proxy):
......@@ -25,7 +21,7 @@ class Addbmm(LinearBasedBiasFunc):
compute the main computation, such as convolution, with bias option banned.
"""
assert self.substitute_func == torch.bmm
node_kind = 'call_function'
node_kind = "call_function"
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
......@@ -35,10 +31,10 @@ class Addbmm(LinearBasedBiasFunc):
return non_bias_func_proxy
def insert_sum_node(self, input_proxy, sum_dims=0):
'''
"""
This method is used to sum the input_proxy through the sum_dims.
'''
node_kind = 'call_function'
"""
node_kind = "call_function"
node_target = torch.sum
node_args = (input_proxy, sum_dims)
node_kwargs = {}
......@@ -55,15 +51,15 @@ class Addbmm(LinearBasedBiasFunc):
sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
kwargs = self.extract_kwargs_from_origin_func()
if 'beta' in kwargs:
beta = kwargs['beta']
if "beta" in kwargs:
beta = kwargs["beta"]
# doing the multiplication with beta if it exists(temp_2 = beta * input)
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
if 'alpha' in kwargs:
alpha = kwargs['alpha']
if "alpha" in kwargs:
alpha = kwargs["alpha"]
# doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1)
alpha_proxy = self.create_mul_node(alpha, sum_proxy)
else:
......
import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
......@@ -10,17 +7,16 @@ from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addmm)
@bias_addition_function.register(torch.addmm)
class Addmm(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self):
    """
    Collect the scaling keyword arguments of the original call.

    Copies `beta` and `alpha` from `self.kwargs` when they are present and
    ignores every other key.

    Returns:
        dict: a new dict containing only the `beta` / `alpha` entries found.
    """
    kwargs = {}
    for name in ("beta", "alpha"):
        if name in self.kwargs:
            kwargs[name] = self.kwargs[name]
    return kwargs
def transpose_other_operand_for_linear(self, other_proxy):
'''
"""
This method is used to transpose the other operand for linear function.
For example:
input = torch.rand(3, 4)
......@@ -30,8 +26,8 @@ class Addmm(LinearBasedBiasFunc):
# To keep the computation graph consistent with the origin computation graph, we need to transpose the m2
# before we call the linear function.
new_output = torch.linear(m1, m2.transpose(0, 1)) + input
'''
node_kind = 'call_function'
"""
node_kind = "call_function"
node_target = torch.transpose
node_args = (other_proxy, 0, 1)
node_kwargs = {}
......@@ -43,14 +39,14 @@ class Addmm(LinearBasedBiasFunc):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy)
kwargs = self.extract_kwargs_from_origin_func()
if 'beta' in kwargs:
beta = kwargs['beta']
if "beta" in kwargs:
beta = kwargs["beta"]
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
if 'alpha' in kwargs:
alpha = kwargs['alpha']
if "alpha" in kwargs:
alpha = kwargs["alpha"]
alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
else:
alpha_proxy = non_bias_linear_func_proxy
......
......@@ -29,7 +29,6 @@ class BiasAdditionFunc(ABC):
to insert two more operator.mul nodes for the computation graph to compute the
final result.
"""
pass
@abstractmethod
def generate(self):
......@@ -50,7 +49,6 @@ class BiasAdditionFunc(ABC):
%mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
"""
pass
def create_mul_node(self, input_proxy, coefficent):
"""
......@@ -59,7 +57,7 @@ class BiasAdditionFunc(ABC):
Therefore, we need to use this method insert two more operator.mul nodes for
the computation graph to compute the final result.
"""
node_kind = 'call_function'
node_kind = "call_function"
node_target = operator.mul
node_args = (
input_proxy,
......@@ -82,7 +80,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
compute the main computation, such as convolution, with bias option banned.
"""
assert self.substitute_func == torch.nn.functional.linear
node_kind = 'call_function'
node_kind = "call_function"
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
......@@ -96,7 +94,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
This method is used to create the bias_addition_proxy, the node created by this proxy will
compute the sum of non_bias_func result and bias with some reshape operation if needed.
"""
bias_add_node_kind = 'call_function'
bias_add_node_kind = "call_function"
bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
......
import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function
......@@ -9,17 +6,16 @@ from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_function.register(F.linear)
class Linear(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self):
    """
    Collect the `bias` keyword argument of the original call.

    Returns:
        dict: `{"bias": self.kwargs["bias"]}`.

    Raises:
        AssertionError: if `bias` is not a key of `self.kwargs`.
    """
    assert "bias" in self.kwargs
    kwargs = {}
    # Guard kept so that, when asserts are stripped (python -O), a missing
    # bias still yields an empty dict instead of a KeyError here.
    if "bias" in self.kwargs:
        kwargs["bias"] = self.kwargs["bias"]
    return kwargs
def generate(self):
    """
    Build the bias-free call followed by an explicit bias addition.

    Calls `self.create_non_bias_func_proxy` on the first two positional
    arguments of the original call, then passes its result together with
    the extracted `bias` keyword argument to
    `self.create_bias_addition_proxy`.

    Returns:
        The value returned by `self.create_bias_addition_proxy`.
    """
    non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1])
    kwargs = self.extract_kwargs_from_origin_func()
    bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs["bias"])
    return bias_addition_proxy
......@@ -27,8 +27,8 @@ class BiasAdditionModule(ABC):
Note: this function will be invoked during module initializing,
you should never call this function.
"""
weight_node_kind = 'get_attr'
weight_node_target = self.target + '.weight'
weight_node_kind = "get_attr"
weight_node_target = self.target + ".weight"
weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {})
return weight_proxy
......@@ -39,8 +39,8 @@ class BiasAdditionModule(ABC):
Note: this function will be invoked during module initializing,
you should never call this function.
"""
bias_node_kind = 'get_attr'
bias_node_target = self.target + '.bias'
bias_node_kind = "get_attr"
bias_node_target = self.target + ".bias"
bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {})
return bias_proxy
......@@ -54,14 +54,13 @@ class BiasAdditionModule(ABC):
considered during module initializing. However, we need to consider those attributes as kwargs
in F.conv2d.
"""
pass
def create_non_bias_func_proxy(self, input_proxy=None):
"""
This method is used to create the non_bias_func proxy, the node created by this proxy will
compute the main computation, such as convolution, with bias option banned.
"""
node_kind = 'call_function'
node_kind = "call_function"
node_target = self.substitute_func
if input_proxy is None:
input_proxy = self.args[0]
......@@ -75,7 +74,7 @@ class BiasAdditionModule(ABC):
This method is used to create the bias_addition_proxy, the node created by this proxy will
compute the sum of non_bias_func result and bias with some reshape operation if needed.
"""
bias_add_node_kind = 'call_function'
bias_add_node_kind = "call_function"
bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
......@@ -100,7 +99,6 @@ class BiasAdditionModule(ABC):
%view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
"""
pass
module_to_func_dict = {
......
import torch
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple
from torch.nn.modules.utils import _pair, _single, _triple
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
......@@ -10,17 +9,16 @@ from .bias_addition_module import BiasAdditionModule
@bias_addition_module.register(torch.nn.Conv2d)
@bias_addition_module.register(torch.nn.Conv3d)
class BiasAdditionConv(BiasAdditionModule):
def extract_kwargs_from_mod(self):
root = self.tracer.root
conv_module = root.get_submodule(self.target)
kwarg_attributes = ['groups', 'dilation', 'stride']
kwarg_attributes = ["groups", "dilation", "stride"]
non_bias_kwargs = {}
for attr_name in kwarg_attributes:
if hasattr(conv_module, attr_name):
non_bias_kwargs[attr_name] = getattr(conv_module, attr_name)
if conv_module.padding_mode != "zeros":
#TODO: non zeros mode requires some extra processing for input
# TODO: non zeros mode requires some extra processing for input
conv_type = type(conv_module)
if conv_type == "torch.nn.Conv1d":
padding_element = _single(0)
......@@ -28,9 +26,9 @@ class BiasAdditionConv(BiasAdditionModule):
padding_element = _pair(0)
elif conv_type == "torch.nn.Conv3d":
padding_element = _triple(0)
non_bias_kwargs['padding'] = padding_element
non_bias_kwargs["padding"] = padding_element
else:
non_bias_kwargs['padding'] = getattr(conv_module, 'padding')
non_bias_kwargs["padding"] = getattr(conv_module, "padding")
return non_bias_kwargs
......@@ -41,11 +39,12 @@ class BiasAdditionConv(BiasAdditionModule):
"""
bias_shape = [1] * (dimensions - 1)
bias_shape[0] = -1
bias_reshape_node_kind = 'call_method'
bias_reshape_node_target = 'view'
bias_reshape_node_kind = "call_method"
bias_reshape_node_target = "view"
bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
bias_reshape_node_args, {})
bias_reshape_proxy = self.tracer.create_proxy(
bias_reshape_node_kind, bias_reshape_node_target, bias_reshape_node_args, {}
)
return bias_reshape_proxy
def generate(self):
......
import torch
import torch.nn.functional as F
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
......@@ -7,7 +6,6 @@ from .bias_addition_module import BiasAdditionModule
@bias_addition_module.register(torch.nn.Linear)
class BiasAdditionLinear(BiasAdditionModule):
def extract_kwargs_from_mod(self):
    """
    Return the keyword arguments taken from the module.

    Returns:
        dict: always a new empty dict; nothing is read from `self`.
    """
    return {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment