"tests/vscode:/vscode.git/clone" did not exist on "638a07a7f9b504e6c9781e9aa2a9b6c5e9dc49ed"
Commit e532679c authored by oahzxl's avatar oahzxl
Browse files

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
from typing import Dict, List, Tuple, Union
import torch
from torch.fx import GraphModule, Node
from .._compatibility import compatibility, is_compatible_with_meta
__all__ = ['activation_size', 'parameter_size', 'is_inplace']
@compatibility(is_backward_compatible=True)
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node.
Args:
out (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional` call.
Returns:
int: The activation size in bytes.
"""
act_size = 0
if isinstance(out, torch.Tensor):
if out.is_quantized:
act_size += out.numel() * torch._empty_affine_quantized([], dtype=out.dtype).element_size()
else:
act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
elif isinstance(out, dict):
value_list = [v for _, v in out.items()]
act_size += activation_size(value_list)
elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
for element in out:
act_size += activation_size(element)
return act_size
@compatibility(is_backward_compatible=True)
def parameter_size(mod: torch.nn.Module) -> int:
"""Calculate parameter size of a node.
Args:
mod (torch.nn.Module): The target `torch.nn.Module`.
Returns:
int: The parameter size in bytes.
"""
param_size = 0
for param in mod.parameters():
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
return param_size
def is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node
Args:
n (Node): the target torch.fx.Node
Returns:
bool: indicates whether this op is inplace
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
if is_compatible_with_meta():
from .constants import ALIAS_ATEN
if n.target in ALIAS_ATEN:
inplace = True
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace
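For reference, a minimal usage sketch of the helpers above (the import path is assumed to be colossalai.fx.profiler.memory_utils; the byte counts are worked out by hand for fp32):

import torch
import torch.nn as nn
from colossalai.fx.profiler.memory_utils import activation_size, parameter_size    # assumed path

# Linear(4 -> 8): 4 * 8 weights + 8 biases = 40 fp32 parameters = 160 bytes.
mod = nn.Linear(4, 8)
out = mod(torch.rand(2, 4))

# activation_size walks tensors and containers, summing numel() * element_size().
assert activation_size(out) == 2 * 8 * 4      # (2, 8) fp32 output = 64 bytes
assert parameter_size(mod) == 40 * 4          # 160 bytes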
@@ -7,6 +7,7 @@ from numbers import Number
from typing import Any, Callable, List
import torch
from packaging import version
aten = torch.ops.aten
@@ -32,7 +33,7 @@ def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    # inputs is a list of length 3.
    input_shapes = [v.shape for v in inputs[1:3]]
    # input_shapes[0]: [batch size, input feature dimension]
    # input_shapes[1]: [input feature dimension, output feature dimension]
    assert len(input_shapes[0]) == 2, input_shapes[0]
    assert len(input_shapes[1]) == 2, input_shapes[1]
    batch_size, input_dim = input_shapes[0]
@@ -188,131 +189,136 @@ def zero_flop_jit(*args):
    return 0
if version.parse(torch.__version__) >= version.parse('1.12.0'):
    flop_mapping = {
        # gemm
        aten.mm.default: matmul_flop_jit,
        aten.matmul.default: matmul_flop_jit,
        aten.addmm.default: addmm_flop_jit,
        aten.bmm.default: bmm_flop_jit,
        # convolution
        aten.convolution.default: conv_flop_jit,
        aten._convolution.default: conv_flop_jit,
        aten.convolution_backward.default: conv_backward_flop_jit,
        # normalization
        aten.native_batch_norm.default: batchnorm_flop_jit,
        aten.native_batch_norm_backward.default: batchnorm_flop_jit,
        aten.cudnn_batch_norm.default: batchnorm_flop_jit,
        aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
        aten.native_layer_norm.default: norm_flop_counter(2, 0),
        aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
        # pooling
        aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
        aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
        aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
        aten.avg_pool3d.default: elementwise_flop_counter(1, 0),
        aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
        aten.max_pool1d.default: elementwise_flop_counter(1, 0),
        aten.max_pool2d.default: elementwise_flop_counter(1, 0),
        aten.max_pool3d.default: elementwise_flop_counter(1, 0),
        aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0),
        aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0),
        aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1),
        aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0),
        aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1),
        aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0),
        aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
        aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
        aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
        aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
        aten.embedding.default: elementwise_flop_counter(1, 0),
    }
    elementwise_flop_aten = [
        # basic op
        aten.add.Tensor,
        aten.add_.Tensor,
        aten.div.Tensor,
        aten.div_.Tensor,
        aten.div.Scalar,
        aten.div_.Scalar,
        aten.mul.Tensor,
        aten.mul.Scalar,
        aten.mul_.Tensor,
        aten.neg.default,
        aten.pow.Tensor_Scalar,
        aten.rsub.Scalar,
        aten.sum.default,
        aten.sum.dim_IntList,
        aten.mean.dim,
        # activation op
        aten.hardswish.default,
        aten.hardswish_.default,
        aten.hardswish_backward.default,
        aten.hardtanh.default,
        aten.hardtanh_.default,
        aten.hardtanh_backward.default,
        aten.hardsigmoid_backward.default,
        aten.hardsigmoid.default,
        aten.gelu.default,
        aten.gelu_backward.default,
        aten.silu.default,
        aten.silu_.default,
        aten.silu_backward.default,
        aten.sigmoid.default,
        aten.sigmoid_backward.default,
        aten._softmax.default,
        aten._softmax_backward_data.default,
        aten.relu_.default,
        aten.relu.default,
        aten.tanh.default,
        aten.tanh_backward.default,
        aten.threshold_backward.default,
        # dropout
        aten.native_dropout.default,
        aten.native_dropout_backward.default,
    ]
    for op in elementwise_flop_aten:
        flop_mapping[op] = elementwise_flop_counter(1, 0)

    # TODO: this will be removed in future
    zero_flop_aten = [
        aten.as_strided.default,
        aten.as_strided_.default,
        aten.bernoulli_.float,
        aten.cat.default,
        aten.clone.default,
        aten.copy_.default,
        aten.detach.default,
        aten.expand.default,
        aten.empty_like.default,
        aten.new_empty.default,
        aten.new_empty_strided.default,
        aten.ones_like.default,
        aten._reshape_alias.default,
        aten.select.int,
        aten.select_backward.default,
        aten.squeeze.dim,
        aten.slice.Tensor,
        aten.slice_backward.default,
        aten.split.Tensor,
        aten.permute.default,
        aten.t.default,
        aten.transpose.int,
        aten._to_copy.default,
        aten.unsqueeze.default,
        aten.unbind.int,
        aten._unsafe_view.default,
        aten.view.default,
        aten.where.self,
        aten.zero_.default,
        aten.zeros_like.default,
    ]
    for op in zero_flop_aten:
        flop_mapping[op] = zero_flop_jit
else:
    flop_mapping = {}
    elementwise_flop_aten = {}
    zero_flop_aten = {}
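The version gate above exists because the aten overload objects used as dictionary keys are only reliably addressable on torch >= 1.12. A minimal sketch of the same pattern in isolation (the counter below is a placeholder that takes plain shape tuples, not the module's real signature):

import math
import torch
from packaging import version

def placeholder_elementwise_counter(inp_factor=1, out_factor=0):
    # Charge `inp_factor` FLOPs per input element and `out_factor` per output element.
    def count(input_shapes, output_shapes):
        return (inp_factor * sum(math.prod(s) for s in input_shapes)
                + out_factor * sum(math.prod(s) for s in output_shapes))
    return count

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    demo_mapping = {torch.ops.aten.relu.default: placeholder_elementwise_counter(1, 0)}
else:
    demo_mapping = {}    # older torch: fall back to empty tables, as in the module above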
@@ -11,7 +11,7 @@ from torch.utils._pytree import tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
from .memory_utils import activation_size, parameter_size
from .opcount import flop_mapping
from .tensor import MetaTensor
@@ -232,12 +232,12 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
    def pack(x):
        global cache, do_not_cache
        if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
            tensor = x._tensor.detach()
            tensor.data_ptr = x._tensor.data_ptr
            x._node.meta['saved_tensor'] += [tensor]
            if not do_not_cache:
                cache.add(x._tensor.data_ptr())
        return x
    def unpack(x):
@@ -270,7 +270,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
    def extract_tensor(x: Any):
        if isinstance(x, MetaTensor):
            tensor = x._tensor.detach()
            tensor.data_ptr = x._tensor.data_ptr
            return tensor
        if not isinstance(x, torch.finfo):
            return x
@@ -286,13 +286,13 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target', device: str = 'meta') -> Callable:
    """
    Wrap a `call_function` node or `torch.nn.functional` in order to
    record the memory cost and FLOPs of the execution.
    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` are available.
    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
@@ -328,6 +328,8 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
        out, meta = _profile_concrete(func, *args, **kwargs)
        if inplace:
            kwargs['inplace'] = True
            meta.bwd_mem_tmp = 0
            meta.bwd_mem_out = 0
        do_not_cache = False
        meta.bwd_mem_out -= param_size
@@ -342,7 +344,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
def profile_method(target: 'Target', device: str = 'meta') -> Callable:
    """
    Wrap a `call_method` node
    record the memory cost and FLOPs of the execution.
    """
    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
@@ -360,13 +362,13 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to
    record the memory cost and FLOPs of the execution.
    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` are available.
    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
@@ -394,6 +396,8 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
        out, meta = _profile_concrete(func, *args, **kwargs)
        if inplace:
            module.inplace = True
            meta.bwd_mem_tmp = 0
            meta.bwd_mem_out = 0
        do_not_cache = False
        # grad for param will not be counted
from typing import Dict, List, Tuple, Union
import torch
from torch.fx import Node
from .._compatibility import compatibility, is_compatible_with_meta
from .memory_utils import activation_size
if is_compatible_with_meta():
    from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
@compatibility(is_backward_compatible=False)
def calculate_fwd_in(n: Node) -> int:
    """A helper function to calculate `fwd_in` (with sharding spec)
    Args:
        n (Node): a node from the graph
@@ -60,11 +20,13 @@ def calculate_fwd_in(n: Node) -> int:
    Returns:
        fwd_in (int): the result of `fwd_in`
    """
    # TODO(super-dainiu): should divide the memory by sharding spec
    return activation_size(n.meta["fwd_in"])
@compatibility(is_backward_compatible=False)
def calculate_fwd_tmp(n: Node) -> int:
    """A helper function to calculate `fwd_tmp` (with sharding spec)
    Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.
    Args:
@@ -74,6 +36,7 @@ def calculate_fwd_tmp(n: Node) -> int:
        fwd_tmp (int): the result of `fwd_tmp`
    """
    # TODO(super-dainiu): should divide the memory by sharding spec
    def is_relu_like_node(n: Node) -> bool:
        """Check if a node is a ReLU-like node.
        ReLU-like nodes have the following properties:
@@ -107,8 +70,9 @@ def calculate_fwd_tmp(n: Node) -> int:
    return 0
@compatibility(is_backward_compatible=False)
def calculate_fwd_out(n: Node) -> int:
    """A helper function to calculate `fwd_out` (with sharding spec)
    Args:
        n (Node): a node from the graph
@@ -117,33 +81,34 @@ def calculate_fwd_out(n: Node) -> int:
    Returns:
        fwd_out (int): the result of `fwd_out`
    """
    # TODO(super-dainiu): should divide the memory by sharding spec
    def intersect(a, b):
        return {k: a[k] for k in a if k in b}
    fwd_in = dict()
    for u in n.users:
        fwd_in.update({x.data_ptr(): x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor)})
    fwd_out = {x.data_ptr(): x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor)}
    return activation_size(intersect(fwd_in, fwd_out))
def calculate_fwd_time(n: Node) -> float:
    """A helper function to calculate `fwd_time` (with sharding spec)
    Args:
        n (Node): a node from the graph
    Returns:
        fwd_time (float): the result of `fwd_time`
    """
    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
    return n.meta["fwd_time"]
def calculate_bwd_time(n: Node) -> float:
    """A helper function to calculate `bwd_time` (with sharding spec)
    Args:
        n (Node): a node from the graph
    Returns:
        bwd_time (float): the result of `bwd_time`
    """
    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
    return n.meta["bwd_time"]
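To make the intersection logic in calculate_fwd_out concrete, here is a small standalone sketch (names are illustrative, not the profiler's API): fwd_out only counts the output tensors that some user node actually reads, matched by data_ptr.

import torch

def tensor_bytes(t: torch.Tensor) -> int:
    return t.numel() * t.element_size()

def fwd_out_bytes(node_outputs, users_inputs):
    # Key both sides by data_ptr and keep only the outputs a user consumes.
    fwd_out = {t.data_ptr(): t for t in node_outputs}
    fwd_in = {t.data_ptr(): t for t in users_inputs}
    shared = {k: v for k, v in fwd_out.items() if k in fwd_in}
    return sum(tensor_bytes(t) for t in shared.values())

a = torch.rand(2, 8)    # consumed by a downstream node
b = torch.rand(2, 8)    # never read again
assert fwd_out_bytes([a, b], [a]) == 64    # only `a` is counted (2 * 8 * 4 bytes)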
@@ -12,10 +12,11 @@ from .constants import ALIAS_ATEN
__all__ = ['MetaTensor']
def set_data_ptr(x):
    if isinstance(x, torch.Tensor):
        if not x.data_ptr():
            data_ptr = uuid.uuid4()
            x.data_ptr = lambda: data_ptr
@compatibility(is_backward_compatible=False)
@@ -53,7 +54,7 @@ class MetaTensor(torch.Tensor):
        if not r._tensor.is_meta:
            r._tensor = r._tensor.to(torch.device('meta'))
        # only tensor not on `meta` should be copied to `meta`
        set_data_ptr(r._tensor)
        return r
    def __repr__(self):
@@ -88,7 +89,7 @@ class MetaTensor(torch.Tensor):
        # here we keep the data_ptr of the input because ALIAS_ATEN ops do not generate a physical copy
        # of the input
        if func in ALIAS_ATEN:
            out.data_ptr = args[0].data_ptr
        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
@@ -127,3 +128,13 @@ class MetaTensor(torch.Tensor):
        if device is not None:
            result = MetaTensor(result, fake_device=device)
        return result
    def cpu(self, *args, **kwargs):
        if self.device.type == 'cpu':
            return self.to(*args, **kwargs)
        return self.to(*args, device='cpu', **kwargs)
    def cuda(self, *args, **kwargs):
        if self.device.type == 'cuda':
            return self.to(*args, **kwargs)
        return self.to(*args, device='cuda', **kwargs)
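Meta tensors carry no storage, so their data_ptr() is not usable as an identity on its own; set_data_ptr above substitutes a stable fake pointer so that aliases can be tracked. A minimal sketch of the same trick outside the class (hypothetical helper name, mirroring the behaviour the code above relies on for torch ~1.12):

import uuid
import torch

def set_fake_data_ptr(x):
    # Give storage-less meta tensors a stable fake pointer, as set_data_ptr does above.
    if isinstance(x, torch.Tensor) and not x.data_ptr():
        fake_ptr = uuid.uuid4()
        x.data_ptr = lambda: fake_ptr

t = torch.empty(4, device='meta')
set_fake_data_ptr(t)
alias = t.view(2, 2)
alias.data_ptr = t.data_ptr    # ALIAS_ATEN ops propagate the pointer instead of minting a new one
assert alias.data_ptr() == t.data_ptr()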
from colossalai.fx.tracer.meta_patch.patched_function.python_ops import operator_getitem
from ._meta_trace import meta_trace
from ._symbolic_trace import symbolic_trace
from .tracer import ColoTracer
from colossalai.fx.profiler.memory import activation_size
import torch
from torch.fx import Graph, Node
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map
from typing import Any, Callable, Dict, Optional, Union
import torch
from colossalai.fx import ColoGraphModule
from colossalai.fx._compatibility import compatibility
from .tracer import ColoTracer
@compatibility(is_backward_compatible=True)
def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None,
) -> ColoGraphModule:
"""
Symbolic tracing API
Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule``
constructed by recording operations seen while tracing through ``root``.
With ``meta_args``, we can trace models that are otherwise untraceable due to data-dependent control flow. If only
``meta_args`` is specified, tracing can be done ahead of time.
Note that ``meta_args`` is a keyword dict that maps each argument's name to its (meta) value.
Uses:
>>> model = ...
# if this works
>>> gm = symbolic_trace(model, concrete_args=concrete_args)
# else try this
>>> gm = symbolic_trace(model, concrete_args=concrete_args, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})
Args:
root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted
into a Graph representation.
concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be used for tracing.
meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``.
Defaults to None.
Returns:
ColoGraphModule: A ``ColoGraphModule`` created from the recorded operations from ``root``.
Warnings:
This API is still under development and may contain bugs. Feel free to report any bugs to the Colossal-AI team.
"""
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
from .patched_bias_addition_function import *
from .patched_bias_addition_module import *
from .addbmm import Addbmm
from .addmm import Addmm
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict
from .linear import Linear
import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addbmm)
@bias_addition_function.register(torch.addbmm)
class Addbmm(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self):
kwargs = {}
if 'beta' in self.kwargs:
kwargs['beta'] = self.kwargs['beta']
if 'alpha' in self.kwargs:
kwargs['alpha'] = self.kwargs['alpha']
return kwargs
def create_non_bias_func_proxy(self, input_proxy, other_proxy):
"""
This method is used to create the non_bias_func proxy, the node created by this proxy will
compute the main computation, such as convolution, with bias option banned.
"""
assert self.substitute_func == torch.bmm
node_kind = 'call_function'
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
# torch.bmm does not have any kwargs
node_kwargs = {}
non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return non_bias_func_proxy
def insert_sum_node(self, input_proxy, sum_dims=0):
'''
This method is used to sum the input_proxy through the sum_dims.
'''
node_kind = 'call_function'
node_target = torch.sum
node_args = (input_proxy, sum_dims)
node_kwargs = {}
sum_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return sum_proxy
def generate(self):
# The formula for addbmm is output = beta * input + alpha * torch.sum(torch.bmm(b1, b2), dim=0)
# doing the non-bias computation(temp_0 = torch.bmm(b1, b2))
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], self.args[2])
# doing sum on the batch dimension(temp_1 = torch.sum(temp_0, 0))
sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
kwargs = self.extract_kwargs_from_origin_func()
if 'beta' in kwargs:
beta = kwargs['beta']
# doing the multiplication with beta if it exists(temp_2 = beta * input)
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
if 'alpha' in kwargs:
alpha = kwargs['alpha']
# doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1)
alpha_proxy = self.create_mul_node(alpha, sum_proxy)
else:
alpha_proxy = sum_proxy
# doing the addition(temp_4 = temp_2 + temp_3)
bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)
return bias_addition_proxy
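The rewrite above is numerically equivalent to the original call. A quick check of the decomposition Addbmm.generate emits, in plain torch:

import torch

inp = torch.rand(3, 5)
b1 = torch.rand(4, 3, 2)
b2 = torch.rand(4, 2, 5)
beta, alpha = 2.0, 3.0

reference = torch.addbmm(inp, b1, b2, beta=beta, alpha=alpha)
# bmm -> sum over the batch dimension -> scale by alpha -> add beta * input
decomposed = beta * inp + alpha * torch.sum(torch.bmm(b1, b2), 0)
assert torch.allclose(reference, decomposed, atol=1e-6)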
import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addmm)
@bias_addition_function.register(torch.addmm)
class Addmm(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self):
kwargs = {}
if 'beta' in self.kwargs:
kwargs['beta'] = self.kwargs['beta']
if 'alpha' in self.kwargs:
kwargs['alpha'] = self.kwargs['alpha']
return kwargs
def transpose_other_operand_for_linear(self, other_proxy):
'''
This method is used to transpose the other operand for linear function.
For example:
input = torch.rand(3, 4)
m1 = torch.rand(3, 5)
m2 = torch.rand(5, 4)
original_output = torch.addmm(input, m1, m2)
# To keep the computation graph consistent with the origin computation graph, we need to transpose the m2
# before we call the linear function.
new_output = F.linear(m1, m2.transpose(0, 1)) + input
'''
node_kind = 'call_function'
node_target = torch.transpose
node_args = (other_proxy, 0, 1)
node_kwargs = {}
transpose_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return transpose_proxy
def generate(self):
transpose_proxy = self.transpose_other_operand_for_linear(self.args[2])
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy)
kwargs = self.extract_kwargs_from_origin_func()
if 'beta' in kwargs:
beta = kwargs['beta']
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
if 'alpha' in kwargs:
alpha = kwargs['alpha']
alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
else:
alpha_proxy = non_bias_linear_func_proxy
bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)
return bias_addition_proxy
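Likewise, the Addmm rewrite reproduces torch.addmm through F.linear on the transposed operand; a quick numerical check:

import torch
import torch.nn.functional as F

inp = torch.rand(3, 4)
m1 = torch.rand(3, 5)
m2 = torch.rand(5, 4)
beta, alpha = 2.0, 3.0

reference = torch.addmm(inp, m1, m2, beta=beta, alpha=alpha)
# transpose m2 -> F.linear -> scale by alpha -> add beta * input
decomposed = beta * inp + alpha * F.linear(m1, m2.transpose(0, 1))
assert torch.allclose(reference, decomposed, atol=1e-6)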
import operator
from abc import ABC, abstractmethod
import torch
import torch.nn.functional as F
class BiasAdditionFunc(ABC):
"""
This class is used to construct the restructure computation graph for
call_func node with bias addition inside.
"""
def __init__(self, tracer, target, args, kwargs, substitute_func):
self.tracer = tracer
self.target = target
self.args = args
self.kwargs = kwargs
self.substitute_func = substitute_func
@abstractmethod
def extract_kwargs_from_origin_func(self):
"""
This method is used to extract the kwargs for further graph transform.
For example:
The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
The kwargs for addmm function is {beta=1, alpha=1, output=None}, then we need
to insert two more operator.mul nodes for the computation graph to compute the
final result.
"""
pass
@abstractmethod
def generate(self):
"""
This method is used to construct the whole restructure computation graph for call_func node with bias
addition inside.
A whole restructure computation graph will contain a weight node, a bias node, a non-bias addition computation node,
a bias reshape node if needed and a bias addition node.
Use torch.addmm as an example:
The origin node is:
%addmm: call_func[target=torch.addmm](args = (%input_1, m1, m2), kwargs = {beta=1, alpha=1})
Restructured graph is:
%transpose : [#users=1] = call_function[target=torch.transpose](args = (%m2, 0, 1), kwargs = {})
%linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%m1, %transpose), kwargs = {})
%mul : [#users=1] = call_function[target=operator.mul](args = (%input_1, 3), kwargs = {})
%mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
"""
pass
def create_mul_node(self, input_proxy, coefficient):
"""
This method is used to create a coefficient node for numerical correctness.
The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
Therefore, we need to use this method to insert two more operator.mul nodes for
the computation graph to compute the final result.
"""
node_kind = 'call_function'
node_target = operator.mul
node_args = (
input_proxy,
coefficient,
)
node_kwargs = {}
mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return mul_proxy
class LinearBasedBiasFunc(BiasAdditionFunc):
"""
This class is used to construct the restructure computation graph for
call_func node based on F.linear.
"""
def create_non_bias_func_proxy(self, input_proxy, other_proxy):
"""
This method is used to create the non_bias_func proxy, the node created by this proxy will
compute the main computation, such as convolution, with bias option banned.
"""
assert self.substitute_func == torch.nn.functional.linear
node_kind = 'call_function'
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
# non-bias linear does not have any kwargs
node_kwargs = {}
non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return non_bias_func_proxy
def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy):
"""
This method is used to create the bias_addition_proxy, the node created by this proxy will
compute the sum of non_bias_func result and bias with some reshape operation if needed.
"""
bias_add_node_kind = 'call_function'
bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
return bias_add_proxy
func_to_func_dict = {
torch.addmm: F.linear,
torch.addbmm: torch.bmm,
F.linear: F.linear,
}
method_to_func_dict = {
torch.Tensor.addmm: F.linear,
torch.Tensor.addbmm: torch.bmm,
}
import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function
from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_function.register(F.linear)
class Linear(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self):
assert 'bias' in self.kwargs
kwargs = {}
if 'bias' in self.kwargs:
kwargs['bias'] = self.kwargs['bias']
return kwargs
def generate(self):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1])
kwargs = self.extract_kwargs_from_origin_func()
bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias'])
return bias_addition_proxy
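The Linear rewrite simply splits the fused bias out of F.linear; the two forms agree exactly:

import torch
import torch.nn.functional as F

x = torch.rand(2, 5)
w = torch.rand(3, 5)
b = torch.rand(3)

fused = F.linear(x, w, bias=b)
split = F.linear(x, w) + b    # what Linear.generate emits: bias-free linear followed by an explicit add
assert torch.allclose(fused, split, atol=1e-6)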
from .bias_addition_module import *
from .conv import *
from .linear import *
import operator
from abc import ABC, abstractmethod
import torch
import torch.nn.functional as F
class BiasAdditionModule(ABC):
"""
This class is used to construct the restructure computation graph for
call_module node with bias addition inside.
"""
def __init__(self, tracer, target, args, kwargs, substitute_func):
self.tracer = tracer
self.target = target
self.args = args
self.kwargs = kwargs
self.substitute_func = substitute_func
self.weight_proxy = self._create_weight_proxy()
self.bias_proxy = self._create_bias_proxy()
def _create_weight_proxy(self):
"""
Create weight proxy, the node created by this proxy contains module weight.
Note: this function is invoked during module initialization;
you should never call it yourself.
"""
weight_node_kind = 'get_attr'
weight_node_target = self.target + '.weight'
weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {})
return weight_proxy
def _create_bias_proxy(self):
"""
Create bias proxy, the node created by this proxy contains module bias.
Note: this function is invoked during module initialization;
you should never call it yourself.
"""
bias_node_kind = 'get_attr'
bias_node_target = self.target + '.bias'
bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {})
return bias_proxy
@abstractmethod
def extract_kwargs_from_mod(self):
"""
This method is used to extract the kwargs for non-bias computation.
For example:
The kwargs for a conv2d module call are {} because attributes like 'padding' or 'groups' are
handled during module initialization. However, we need to pass those attributes as kwargs
to F.conv2d.
"""
pass
def create_non_bias_func_proxy(self, input_proxy=None):
"""
This method is used to create the non_bias_func proxy, the node created by this proxy will
compute the main computation, such as convolution, with bias option banned.
"""
node_kind = 'call_function'
node_target = self.substitute_func
if input_proxy is None:
input_proxy = self.args[0]
node_args = (input_proxy, self.weight_proxy)
node_kwargs = self.extract_kwargs_from_mod()
non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return non_bias_func_proxy
def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy):
"""
This method is used to create the bias_addition_proxy, the node created by this proxy will
compute the sum of non_bias_func result and bias with some reshape operation if needed.
"""
bias_add_node_kind = 'call_function'
bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
return bias_add_proxy
@abstractmethod
def generate(self):
"""
This method is used to construct the whole restructure computation graph for call_module node with bias
addition inside.
A whole restructure computation graph will contain a weight node, a bias node, a non-bias addition computation node,
a bias reshape node if needed and a bias addition node.
Use Conv2d module as an example:
The origin node is:
%conv: call_module[target=conv](args = (%x,), kwargs = {})
Restructured graph is:
%conv_weight : [#users=1] = get_attr[target=conv.weight]
%conv_bias : [#users=1] = get_attr[target=conv.bias]
%conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
%view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
"""
pass
module_to_func_dict = {
torch.nn.Linear: F.linear,
torch.nn.Conv1d: F.conv1d,
torch.nn.Conv2d: F.conv2d,
torch.nn.Conv3d: F.conv3d,
}
import torch
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
@bias_addition_module.register(torch.nn.Conv1d)
@bias_addition_module.register(torch.nn.Conv2d)
@bias_addition_module.register(torch.nn.Conv3d)
class BiasAdditionConv(BiasAdditionModule):
def extract_kwargs_from_mod(self):
root = self.tracer.root
conv_module = root.get_submodule(self.target)
kwarg_attributes = ['groups', 'dilation', 'stride']
non_bias_kwargs = {}
for attr_name in kwarg_attributes:
if hasattr(conv_module, attr_name):
non_bias_kwargs[attr_name] = getattr(conv_module, attr_name)
if conv_module.padding_mode != "zeros":
#TODO: non zeros mode requires some extra processing for input
conv_type = type(conv_module)
if conv_type == torch.nn.Conv1d:
padding_element = _single(0)
elif conv_type == torch.nn.Conv2d:
padding_element = _pair(0)
elif conv_type == torch.nn.Conv3d:
padding_element = _triple(0)
non_bias_kwargs['padding'] = padding_element
else:
non_bias_kwargs['padding'] = getattr(conv_module, 'padding')
return non_bias_kwargs
def create_bias_reshape_proxy(self, dimensions):
"""
This method is used to reshape the bias node in order to make bias and
output of non-bias convolution broadcastable.
"""
bias_shape = [1] * (dimensions - 1)
bias_shape[0] = -1
bias_reshape_node_kind = 'call_method'
bias_reshape_node_target = 'view'
bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
bias_reshape_node_args, {})
return bias_reshape_proxy
def generate(self):
non_bias_conv_func_proxy = self.create_non_bias_func_proxy()
output_dims = non_bias_conv_func_proxy.meta_data.dim()
bias_reshape_proxy = self.create_bias_reshape_proxy(output_dims)
bias_addition_proxy = self.create_bias_addition_proxy(non_bias_conv_func_proxy, bias_reshape_proxy)
return bias_addition_proxy
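The reshape-then-add sequence produced for convolution modules matches the fused bias path; a quick check for a Conv2d with default 'zeros' padding:

import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.rand(2, 3, 16, 16)

fused = conv(x)
# bias-free conv with the module's own hyper-parameters, then the bias broadcast over (N, C, H, W)
no_bias = F.conv2d(x, conv.weight, stride=conv.stride, padding=conv.padding,
                   dilation=conv.dilation, groups=conv.groups)
split = no_bias + conv.bias.view(1, -1, 1, 1)
assert torch.allclose(fused, split, atol=1e-6)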
import torch
import torch.nn.functional as F
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
@bias_addition_module.register(torch.nn.Linear)
class BiasAdditionLinear(BiasAdditionModule):
def extract_kwargs_from_mod(self):
return {}
def generate(self):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy()
bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, self.bias_proxy)
return bias_addition_proxy
import enum
import functools
import operator
import inspect
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
from torch.fx import Graph, Node, Proxy, Tracer
from torch.utils._pytree import tree_map
from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta
from colossalai.fx.tracer._tracer_utils import extract_meta, is_element_in_list
from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
from colossalai.fx.tracer.registry import (
bias_addition_function,
bias_addition_method,
bias_addition_module,
meta_patched_function,
meta_patched_module,
)
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
Target = Union[Callable[..., Any], str]
Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
List[Any], # actually Argument
Dict[str, Any], # actually Argument
slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
'Node',]]
_CScriptMethod = ['add', 'mul', 'sub', 'div']
_TorchNewMethod = [
"arange", "zeros", "zeros_like", "ones", "ones_like", "full", "full_like", "empty", "empty_like", "eye", "tensor",
"finfo"
]
_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"]
def _truncate_suffix(s: str):
import re
return re.sub(r'_\d+$', '', s)
def default_device():
return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
@compatibility(is_backward_compatible=False)
class ColoProxy(Proxy):
def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, **kwargs)
self._meta_data = data
@property
def meta_data(self):
return self._meta_data
@meta_data.setter
def meta_data(self, args):
wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
self._meta_data = tree_map(wrap_fn, args)
@classmethod
def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
kwargs = {} if kwargs is None else kwargs
if proxy.meta_data is None:
proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
return proxy
@classmethod
def from_torch_proxy(cls, proxy: Proxy):
return cls(proxy.node, proxy.tracer)
def __repr__(self):
return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"
def __len__(self):
return len(self.meta_data)
def __int__(self):
return int(self.meta_data)
def __index__(self):
try:
return int(self.meta_data)
except:
return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()
def __float__(self):
return float(self.meta_data)
def __bool__(self):
return self.meta_data
def __getattr__(self, k):
return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value):
proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data
return proxy
def __contains__(self, key):
if self.node.op == "placeholder":
# this is used to handle cases like `if x in kwargs:`
# we don't handle this case for now
return False
return super().__contains__(key)
def __isinstancecheck__(self, type):
return isinstance(self.meta_data, type)
@property
def shape(self):
return self.meta_data.shape
@property
def ndim(self):
return self.meta_data.ndim
@property
def device(self):
proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
proxy.meta_data = self.meta_data.device
return proxy
@property
def dtype(self):
proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
proxy.meta_data = self.meta_data.dtype
return proxy
def to(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})
def cpu(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})
def cuda(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})
@compatibility(is_backward_compatible=False)
class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
self.tracer = root.tracer
self._meta_data = data
self._node: Optional[Node] = None
@property
def node(self):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"
@compatibility(is_backward_compatible=False)
class ColoTracer(Tracer):
def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
super().__init__(*args, **kwargs)
self._disable_module_getattr = False
self.proxy_buffer_attributes = True
# whether the tracer will record the usage of torch.utils.checkpoint
self.trace_act_ckpt = trace_act_ckpt
# whether the current tracing occurs within the activation checkpoint functions
self.inside_torch_checkpoint_func = False
self.act_ckpt_region_count = 0
def proxy(self, node: Node) -> 'ColoProxy':
return ColoProxy(node, self)
def create_proxy(self,
kind: str,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
if kind == 'placeholder':
proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
_truncate_suffix(target), None)
elif kind == 'get_attr':
self._disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
proxy.meta_data = attr_itr
finally:
self._disable_module_getattr = False
elif kind == 'call_function':
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
elif kind == 'call_method':
self._disable_module_getattr = True
try:
if target == '__call__':
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
**tree_map(unwrap_fn, kwargs))
finally:
self._disable_module_getattr = False
elif kind == 'call_module':
mod = self.root.get_submodule(target)
self._disable_module_getattr = True
try:
proxy.meta_data = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
finally:
self._disable_module_getattr = False
return proxy
def create_node(self, *args, **kwargs) -> Node:
node = super().create_node(*args, **kwargs)
if self.inside_torch_checkpoint_func:
# annotate the activation checkpoint module
node.meta['activation_checkpoint'] = self.act_ckpt_region_count
return node
def trace(self,
root: torch.nn.Module,
concrete_args: Optional[Dict[str, torch.Tensor]] = None,
meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
if meta_args is None:
meta_args = {}
if concrete_args is None:
concrete_args = {}
# check concrete and meta args have valid names
sig = inspect.signature(root.forward)
sig_names = set(sig.parameters.keys())
meta_arg_names = set(meta_args.keys())
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
if k in non_meta_arg_names and \
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
# get non concrete arg names
concrete_arg_names = set(concrete_args.keys())
non_concrete_arg_names = sig_names - concrete_arg_names
def _check_arg_name_valid(names):
success, element = is_element_in_list(names, sig_names)
if not success:
raise KeyError(
f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
_check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names)
self.concrete_args = concrete_args
self.meta_args = meta_args
with _TorchTensorOverride(self), self.trace_activation_checkpoint(enabled=self.trace_act_ckpt):
self.graph = super().trace(root, concrete_args=concrete_args)
self.graph.lint()
return self.graph
@contextmanager
def trace_activation_checkpoint(self, enabled: bool):
if enabled:
orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
class PatchedCheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
# signal that the current tracing occurs within the activation checkpoint region
self.inside_torch_checkpoint_func = True
out = run_function(*args)
self.inside_torch_checkpoint_func = False
self.act_ckpt_region_count += 1
return out
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError(
"We do not implement the backward pass as we only trace the forward pass.")
# override the checkpoint function
torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
yield
if enabled:
# recover the checkpoint function upon exit
torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
def _post_check(self, non_concrete_arg_names: Set[str]):
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
for node in self.graph.nodes:
if node.op == "placeholder":
# Removing default values for inputs as the forward pass will fail with them.
if node.target in non_concrete_arg_names:
node.args = ()
# Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
# It cannot infer on the attributes and methods the input should have, and fails.
node.type = torch.Tensor
# It is a concrete arg so it is not used and should be removed.
else:
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
# Newer versions of torch.fx emit an assert statement
# for concrete arguments; delete those before we delete
# the concrete arg.
to_delete = []
for user in node.users:
if user.target == torch.fx._symbolic_trace._assert_is_none:
to_delete.append(user)
for user in to_delete:
self.graph.erase_node(user)
self.graph.erase_node(node)
# TODO: solves GraphModule creation.
# Without this, return type annotation "Tuple" is causing code execution failure.
if node.op == "output":
node.type = None
self.graph.lint()
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, "_disable_module_getattr", False):
return attr_val
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
for n, p in collection_to_search:
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
lambda node: ColoProxy(self, node, n, attr_val))
val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache)
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
parameter_proxy_cache)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
return attr_val
@compatibility(is_backward_compatible=True)
def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None,
) -> ColoGraphModule:
if is_compatible_with_meta():
if meta_args is not None:
root.to(default_device())
wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
root.cpu()
else:
graph = Tracer().trace(root, concrete_args=concrete_args)
else:
from .tracer import ColoTracer as OrigColoTracer
graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
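A short usage sketch of the symbolic_trace defined above, mirroring the meta-args example in the _symbolic_trace docstring (assumes colossalai is installed and meta-tensor support is available):

import torch
import torch.nn as nn

class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.fc2 = nn.Linear(32, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Tracing with meta_args avoids materializing real activations; the result is a
# ColoGraphModule that behaves like any torch.fx GraphModule.
gm = symbolic_trace(TinyMLP(), meta_args={'x': torch.rand(2, 16, device='meta')})
print(gm.graph)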
@compatibility(is_backward_compatible=False)
class _TorchTensorOverride(object):
def __init__(self, tracer: Tracer):
self.overrides = {}
self.tracer = tracer
def __enter__(self):
def wrap_tensor_method(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
isinstance(p, ColoProxy) for p in kwargs.values())
if is_proxy:
# if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy
self.tracer._disable_module_getattr = True
try:
proxy = self.tracer.create_proxy('call_function', target, args, kwargs)
finally:
self.tracer._disable_module_getattr = False
return proxy
else:
return target(*args, **kwargs)
return wrapper, target
self.overrides = {
target: wrap_tensor_method(getattr(torch, target))
for target in _TorchNewMethod
if callable(getattr(torch, target))
}
for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, wrapper)
def __exit__(self, exc_type, exc_val, exc_tb):
for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, orig)
def meta_prop_pass(gm: ColoGraphModule,
root: torch.nn.Module,
meta_args: Optional[Dict[str, Any]] = None,
concrete_args: Optional[Dict[str, torch.Tensor]] = None):
if meta_args is None:
meta_args = {}
if concrete_args is None:
concrete_args = {}
# check concrete and meta args have valid names
sig = inspect.signature(root.forward)
sig_names = set(sig.parameters.keys())
meta_arg_names = set(meta_args.keys())
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
if k in non_meta_arg_names and \
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
for node in gm.graph.nodes:
node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
node.kwargs)
def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
if kind == 'placeholder':
meta_out = meta_args[target] if target in meta_args else concrete_args.get(
_truncate_suffix(target), None)
elif kind == 'get_attr':
attr_itr = root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
meta_out = attr_itr
elif kind == 'call_function':
meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
elif kind == 'call_method':
if target == '__call__':
meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
**tree_map(unwrap_fn, kwargs))
elif kind == 'call_module':
mod = root.get_submodule(target)
meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
else:
meta_out = None
return meta_out
def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
meta_out = meta_args[target]
return meta_out
if target in [getattr(torch, torch_func) for torch_func in _TorchNewMethod]:
# NOTE: tensor constructors in PyTorch define the `device` argument as
# *kwargs-only*. That is why this works. If you add methods to
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
# this will break and you will likely see issues where we cannot infer
# the size of the output.
if "device" in kwargs:
kwargs["device"] = "meta"
try:
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
args_metas = tree_map(unwrap_fn, args)
kwargs_metas = tree_map(unwrap_fn, kwargs)
if kind == "call_function":
# fetch patched function
if meta_patched_function.has(target):
meta_target = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
meta_target = meta_patched_function.get(target.__name__)
else:
meta_target = target
meta_out = meta_target(*args_metas, **kwargs_metas)
if isinstance(meta_out, torch.Tensor):
meta_out = meta_out.to(device="meta")
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
# fetch patched method
if meta_patched_function.has(method):
meta_target = meta_patched_function.get(method)
else:
meta_target = method
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == "call_module":
mod = root.get_submodule(target)
mod_type = type(mod)
if meta_patched_module.has(mod_type):
meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas)
else:
meta_out = mod(*args_metas, **kwargs_metas)
elif kind == "get_attr":
attr_itr = root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.nn.parameter.Parameter):
meta_out = torch.nn.Parameter(attr_itr.to(device="meta"))
elif isinstance(attr_itr, torch.Tensor):
meta_out = attr_itr.to(device="meta")
else:
meta_out = attr_itr
else:
return None
except Exception as e:
raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
return meta_out
def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]]=None):
result_graph = Graph()
value_remap = {}
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
for orig_node in gm.graph.nodes:
assert hasattr(orig_node, "_meta_data")
kind = orig_node.op
target = orig_node.target
args = orig_node.args
kwargs = orig_node.kwargs
args_metas = tree_map(unwrap_fn, args)
tracer = ColoTracer()
tracer.graph = Graph(tracer_cls=ColoTracer)
tracer.root = root_model
def wrap_fn(n):
if isinstance(n, Node):
proxy = ColoProxy(n, tracer)
proxy.meta_data = n._meta_data
return proxy
return n
args_proxy = tree_map(wrap_fn, args)
kwargs_proxy = tree_map(wrap_fn, kwargs)
handle = None
if kind == "call_function":
if bias_addition_function.has(target):
if target == torch.nn.functional.linear:
if 'bias' in kwargs and kwargs['bias'] is not None:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
else:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
if bias_addition_method.has(method):
function_to_substitute = method_to_func_dict[method]
handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
elif kind == "call_module":
# if not hasattr(self, "orig_forward"):
# raise AttributeError(f"{self} does not have an attribute called orig_forward")
mod = gm.get_submodule(target)
mod_type = type(mod)
if bias_addition_module.has(mod_type) and mod.bias is not None:
function_to_substitute = module_to_func_dict[mod_type]
handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
if handle is not None:
handle.generate()
for node_inserted in tracer.graph.nodes:
value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n : value_remap[n])
last_node = value_remap[node_inserted]
value_remap[orig_node] = last_node
else:
value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n : value_remap[n])
del tracer
gm.graph = result_graph
gm.recompile()
meta_prop_pass(gm, root_model, meta_args)
from .registry import *
from .patched_function import *
from .patched_module import *