Unverified Commit 393f5940 authored by Super Daniel's avatar Super Daniel Committed by GitHub
Browse files

[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions...

[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)

* [fx] move meta registration

* [fx] fix tests.

* [fx] fix test.

* [fx] fix.

* [meta] refactor meta registration.py.

* [fx] add compatibility descriptions.

* [fx] polish import.

* [fx] add a decorator.

* [fx] fix tests.

* [fx] remove print.

* [fx] edit raise error.

* [fx] edit raise error.

* [fx] add type hint.

* [fx] fix import in experimental.

* [rpc] remove color debug.

* [meta] fix naming.
parent e8d8eda5
# Probe whether the local meta-op registrations import cleanly under the
# installed PyTorch version; downstream code keys off META_COMPATIBILITY.
try:
    from . import _meta_registrations
    META_COMPATIBILITY = True
except Exception:
    # Narrow from a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate; any import/registration failure marks us incompatible.
    import torch
    META_COMPATIBILITY = False
    print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch, from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch,
get_default_parser) get_default_parser)
......
from .tracer import ColoTracer, meta_trace from ._compatibility import compatibility, is_compatible_with_meta
from .graph_module import ColoGraphModule from .graph_module import ColoGraphModule
from .passes import MetaInfoProp from .passes import MetaInfoProp
from .tracer import ColoTracer, meta_trace
from typing import Callable
import torch
# Detect whether the meta-op registrations load under this PyTorch build.
try:
    from . import _meta_registrations
    META_COMPATIBILITY = True
except Exception:
    # Narrow from a bare `except:` — a bare clause would also trap
    # KeyboardInterrupt/SystemExit during import.
    META_COMPATIBILITY = False
def compatibility(is_backward_compatible: bool = False) -> Callable:
    """A decorator to make a function compatible with different versions of PyTorch.

    If the meta-compatibility probe succeeded (``META_COMPATIBILITY``), or the
    function is declared backward compatible, it is returned unchanged.
    Otherwise it is replaced by a stub that raises ``RuntimeError`` at call
    time, naming the incompatible function and the installed PyTorch version.

    Args:
        is_backward_compatible (bool, optional): Whether the function is backward compatible. Defaults to False.

    Returns:
        Callable: The decorated function
    """

    def decorator(func):
        # Both branches of the original returned `func`; merged into one test.
        # `is_backward_compatible` is checked first so it never depends on the
        # module-level probe when compatibility is explicit.
        if is_backward_compatible or META_COMPATIBILITY:
            return func

        # Preserve the wrapped function's __name__/__doc__ so the error
        # message and introspection stay meaningful.
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')

        return wrapper

    return decorator
def is_compatible_with_meta() -> bool:
    """Report whether the meta-op registrations loaded for this PyTorch build.

    Intended to be called before importing meta-dependent ``colossalai.fx``
    modules; a ``False`` result means those modules fall back to their
    experimental counterparts.

    Returns:
        bool: The meta compatibility flag set at import time.
    """
    return META_COMPATIBILITY
# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py # meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
# should be activated for PyTorch version 1.12.0 and below # should be activated for PyTorch version 1.12.0 and below
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
...@@ -31,6 +34,7 @@ def register_meta(op, register_dispatcher=True): ...@@ -31,6 +34,7 @@ def register_meta(op, register_dispatcher=True):
return wrapper return wrapper
# ============================== Convolutions ======================================
# https://github.com/pytorch/pytorch/pull/79834 # https://github.com/pytorch/pytorch/pull/79834
@register_meta(aten.convolution.default) @register_meta(aten.convolution.default)
def meta_conv( def meta_conv(
...@@ -165,6 +169,18 @@ def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: t ...@@ -165,6 +169,18 @@ def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: t
return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta') return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta_adaptive_avg_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
):
    # The gradient w.r.t. the pooling input has the input's own shape/dtype,
    # so an uninitialized meta tensor of that shape is all that is needed.
    return torch.empty_like(input)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
# ============================== Activations =======================================
@register_meta(aten.relu.default) @register_meta(aten.relu.default)
def meta_relu(input: torch.Tensor): def meta_relu(input: torch.Tensor):
return torch.empty_like(input) return torch.empty_like(input)
...@@ -192,11 +208,8 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val: ...@@ -192,11 +208,8 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val:
return grad_in return grad_in
@register_meta(aten.roll.default) # ============================== Normalization =====================================
def meta_roll(input: torch.Tensor, shifts, dims): # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
return input
@register_meta(aten.native_batch_norm.default) @register_meta(aten.native_batch_norm.default)
def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
n_input = input.size(1) n_input = input.size(1)
...@@ -207,6 +220,7 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini ...@@ -207,6 +220,7 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
return output, running_mean, running_var return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default) @register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean, def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
save_invstd, train, eps, output_mask): save_invstd, train, eps, output_mask):
...@@ -241,6 +255,7 @@ def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch. ...@@ -241,6 +255,7 @@ def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.
return dX, dgamma, dbeta return dX, dgamma, dbeta
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm.default) @register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
bs = input.size(0) bs = input.size(0)
...@@ -252,6 +267,7 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): ...@@ -252,6 +267,7 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
return output, running_mean, running_var return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default) @register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
grad_input_mask): grad_input_mask):
...@@ -261,13 +277,18 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me ...@@ -261,13 +277,18 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
return dX, dgamma, dbeta return dX, dgamma, dbeta
@register_meta(aten._adaptive_avg_pool2d_backward.default) # ================================== Misc ==========================================
def meta_adaptive_avg_pool2d_backward( #https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
grad_output: torch.Tensor, @register_meta(aten.roll.default)
input: torch.Tensor, def meta_roll(input: torch.Tensor, shifts, dims):
): return input
grad_input = torch.empty_like(input)
return torch.empty_like(input)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
    # Output dtype follows type promotion of (self, other); the shape mirrors
    # `self` (broadcasting against `condition`/`other` is not modeled here).
    return torch.empty_like(self, dtype=torch.result_type(self, other))
@register_meta(aten.index.Tensor) @register_meta(aten.index.Tensor)
...@@ -360,6 +381,8 @@ def meta_index_Tensor(self, indices): ...@@ -360,6 +381,8 @@ def meta_index_Tensor(self, indices):
return self.new_empty(before_shape + replacement_shape + after_shape) return self.new_empty(before_shape + replacement_shape + after_shape)
# ============================== Embedding =========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default) @register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
scale_grad_by_freq): scale_grad_by_freq):
...@@ -369,13 +392,7 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens ...@@ -369,13 +392,7 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
layout=grad_output.layout) layout=grad_output.layout)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp # ============================== Dropout ===========================================
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
result_type = torch.result_type(self, other)
return torch.empty_like(self, dtype=result_type)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout.default) @register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False): def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
......
from typing import List, Tuple
import copy import copy
import math
from typing import List, Tuple
import torch import torch
from torch.fx import GraphModule, Node from colossalai.fx import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import \
_find_nested_ckpt_regions
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import parameter_size from colossalai.fx.profiler import parameter_size
import math from torch.fx import GraphModule, Node
from .linearize import linearize from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function, Offload, Prefetch from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch,
from colossalai.fx.passes.meta_info_prop import MetaInfoProp Sequence)
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
from colossalai import META_COMPATIBILITY
INF = float("inf") INF = float("inf")
...@@ -508,7 +512,7 @@ def solver_pofo(gm: ColoGraphModule, ...@@ -508,7 +512,7 @@ def solver_pofo(gm: ColoGraphModule,
mem_limit -= parameter_size(gm) mem_limit -= parameter_size(gm)
# prepare data # prepare data
if META_COMPATIBILITY: if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor from colossalai.fx.profiler import MetaTensor
data = MetaTensor(data, fake_device=next(gm.parameters()).device) data = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data) MetaInfoProp(gm).run(data)
......
from dataclasses import asdict from dataclasses import asdict
from colossalai.fx.profiler import GraphInfo from typing import Any, Dict, List, NamedTuple, Optional, Tuple
import torch import torch
import torch.fx import torch.fx
from torch.fx.node import Node, Argument, Target from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_flatten from torch.utils._pytree import tree_flatten
from typing import Any, List, Tuple, NamedTuple, Dict, Optional
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
from torch.fx.graph_module import GraphModule
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
......
from dataclasses import asdict from dataclasses import asdict
from typing import Any, Dict, List, NamedTuple, Tuple
import torch import torch
import torch.fx import torch.fx
from torch.fx.node import Node, Argument, Target from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from typing import Any, List, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo, profile_function, profile_module, profile_method, activation_size, calculate_fwd_out, calculate_fwd_tmp, calculate_fwd_in
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
......
from ... import META_COMPATIBILITY from .._compatibility import is_compatible_with_meta
if META_COMPATIBILITY:
if is_compatible_with_meta():
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .opcount import flop_mapping from .opcount import flop_mapping
from .tensor import MetaTensor
from .profiler import profile_function, profile_method, profile_module from .profiler import profile_function, profile_method, profile_module
from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out from .tensor import MetaTensor
else: else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
from .dataflow import GraphInfo from .dataflow import GraphInfo
from .memory import parameter_size, activation_size, is_inplace from .memory import activation_size, is_inplace, parameter_size
import torch
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from . import META_COMPATIBILITY

# Exported names depend on which branch below runs; start empty and extend.
__all__ = []

if META_COMPATIBILITY:
    aten = torch.ops.aten

    # inplace reshaping
    ALIAS_ATEN = [
        aten.detach.default,
        aten.t.default,
        aten.transpose.int,
        aten.view.default,
        aten._unsafe_view.default,
        aten._reshape_alias.default,
    ]

    INPLACE_NEW = [
        aten.empty_like.default,
        aten.new_empty_strided.default,
    ]

    INPLACE_MATH_ATEN = [
        aten.add_.Tensor,
        aten.sub_.Tensor,
        aten.div_.Tensor,
        aten.div_.Scalar,
        aten.mul_.Tensor,
        aten.bernoulli_.float,
    ]

    CLONE_ATEN = [
        aten.clone.default,
    ]

    # FIX: the previous export list named 'INPLACE_ATEN', which is never
    # defined, and omitted ALIAS_ATEN/INPLACE_NEW; export what exists.
    __all__ += ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
else:
    # TODO fill out the inplace ops
    INPLACE_OPS = [
        add,
        sub,
        mul,
        floordiv,
        neg,
        pos,
        getitem,
        setitem,
        getattr,
        torch.Tensor.cpu,
    ]

    # TODO: list all call_methods that are inplace here
    INPLACE_METHOD = [
        'transpose',
        'permute',
        # TODO: reshape may return a copy of the data if the data is not contiguous
        'reshape',
        'dim',
        'flatten',    # FIX: was listed twice
        'size',
        'view',
        'unsqueeze',
        'to',
        'type',
    ]

    # TODO: list all call_methods that are not inplace here
    NON_INPLACE_METHOD = [
        'chunk',
        'contiguous',
        'expand',
        'mean',
        'split',
    ]

    __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
import torch
__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
aten = torch.ops.aten
ALIAS_ATEN = [
aten.detach.default,
aten.t.default,
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
aten._reshape_alias.default,
]
INPLACE_NEW = [
aten.empty_like.default,
aten.new_empty_strided.default,
]
INPLACE_MATH_ATEN = [
aten.add_.Tensor,
aten.sub_.Tensor,
aten.div_.Tensor,
aten.div_.Scalar,
aten.mul_.Tensor,
aten.bernoulli_.float,
]
CLONE_ATEN = [
aten.clone.default,
]
...@@ -2,7 +2,10 @@ from dataclasses import dataclass, field ...@@ -2,7 +2,10 @@ from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from functools import partial from functools import partial
from typing import Dict, List from typing import Dict, List
from torch.fx import Graph, Node from torch.fx import Graph, Node
from .._compatibility import compatibility
from .memory import activation_size, is_inplace from .memory import activation_size, is_inplace
...@@ -12,6 +15,7 @@ class Phase(Enum): ...@@ -12,6 +15,7 @@ class Phase(Enum):
PLACEHOLDER = 2 PLACEHOLDER = 2
@compatibility(is_backward_compatible=True)
@dataclass @dataclass
class GraphInfo: class GraphInfo:
""" """
...@@ -69,6 +73,7 @@ def is_phase(n: Node, phase: Phase) -> bool: ...@@ -69,6 +73,7 @@ def is_phase(n: Node, phase: Phase) -> bool:
return n.meta['phase'] == phase return n.meta['phase'] == phase
@compatibility(is_backward_compatible=False)
def autograd_graph_analysis(graph: Graph) -> GraphInfo: def autograd_graph_analysis(graph: Graph) -> GraphInfo:
"""Analyze the autograd node dependencies and find out the memory usage. """Analyze the autograd node dependencies and find out the memory usage.
Basically the input graph should have all nodes marked for keyword `phase`. Basically the input graph should have all nodes marked for keyword `phase`.
......
from .registry import meta_profiler_function, meta_profiler_module from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out from .profiler import profile_function, profile_method, profile_module
from .profiler_function import * from .profiler_function import *
from .profiler_module import * from .profiler_module import *
from .profiler import profile_function, profile_method, profile_module from .registry import meta_profiler_function, meta_profiler_module
from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub
import torch
__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
# TODO fill out the inplace ops
INPLACE_OPS = [
add,
sub,
mul,
floordiv,
neg,
pos,
getitem,
setitem,
getattr,
torch.Tensor.cpu,
]
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
'transpose',
'permute',
# TODO: reshape may return a copy of the data if the data is not contiguous
'reshape',
'dim',
'flatten',
'size',
'view',
'unsqueeze',
'to',
'type',
'flatten',
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
'chunk',
'contiguous',
'expand',
'mean',
'split',
]
# for PyTorch 1.11 compatibility uses # for PyTorch 1.11 compatibility uses
from typing import Dict, List, Tuple, Union
import torch import torch
from torch.fx import Node, GraphModule from torch.fx import GraphModule, Node
from typing import Union, Dict, List, Tuple
from ..._compatibility import compatibility
__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"] __all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
@compatibility(is_backward_compatible=True)
def calculate_fwd_in(n: Node) -> bool: def calculate_fwd_in(n: Node) -> bool:
"""A helper function to calculate `fwd_in` """A helper function to calculate `fwd_in`
...@@ -18,6 +22,7 @@ def calculate_fwd_in(n: Node) -> bool: ...@@ -18,6 +22,7 @@ def calculate_fwd_in(n: Node) -> bool:
return n.meta['save_fwd_in'] return n.meta['save_fwd_in']
@compatibility(is_backward_compatible=True)
def calculate_fwd_tmp(n: Node) -> int: def calculate_fwd_tmp(n: Node) -> int:
"""A helper function to calculate `fwd_tmp` """A helper function to calculate `fwd_tmp`
...@@ -30,6 +35,7 @@ def calculate_fwd_tmp(n: Node) -> int: ...@@ -30,6 +35,7 @@ def calculate_fwd_tmp(n: Node) -> int:
return n.meta["fwd_mem_tmp"] return n.meta["fwd_mem_tmp"]
@compatibility(is_backward_compatible=True)
def calculate_fwd_out(n: Node) -> int: def calculate_fwd_out(n: Node) -> int:
"""A helper function to calculate `fwd_out` """A helper function to calculate `fwd_out`
......
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Any, Dict, Tuple from typing import Any, Callable, Dict, Tuple
import torch import torch
from torch.fx.node import Argument, Target from torch.fx.node import Argument, Target
from . import meta_profiler_function, meta_profiler_module
from ..._compatibility import compatibility
from ..memory import activation_size from ..memory import activation_size
from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module
__all__ = ['profile_function', 'profile_module', 'profile_method'] __all__ = ['profile_function', 'profile_module', 'profile_method']
# this is for compatibility use # this is for compatibility use
@compatibility(is_backward_compatible=True)
@dataclass @dataclass
class GraphInfo: class GraphInfo:
""" """
...@@ -69,6 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int ...@@ -69,6 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
""" """
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target') -> Callable: def profile_function(target: 'Target') -> Callable:
""" """
Wrap a `call_function` node or `torch.nn.functional` in order to Wrap a `call_function` node or `torch.nn.functional` in order to
...@@ -106,6 +111,7 @@ def profile_function(target: 'Target') -> Callable: ...@@ -106,6 +111,7 @@ def profile_function(target: 'Target') -> Callable:
return f return f
@compatibility(is_backward_compatible=True)
def profile_method(target: 'Target') -> Callable: def profile_method(target: 'Target') -> Callable:
""" """
Wrap a `call_method` node Wrap a `call_method` node
...@@ -133,6 +139,7 @@ def profile_method(target: 'Target') -> Callable: ...@@ -133,6 +139,7 @@ def profile_method(target: 'Target') -> Callable:
return f return f
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module) -> Callable: def profile_module(module: torch.nn.Module) -> Callable:
""" """
Wrap a `call_module` node or `torch.nn` in order to Wrap a `call_module` node or `torch.nn` in order to
......
...@@ -2,7 +2,6 @@ import operator ...@@ -2,7 +2,6 @@ import operator
from typing import Any, Tuple from typing import Any, Tuple
import torch import torch
from ..registry import meta_profiler_function from ..registry import meta_profiler_function
from colossalai.fx.proxy import ColoProxy
@meta_profiler_function.register(operator.getitem) @meta_profiler_function.register(operator.getitem)
......
from typing import Dict, List, Tuple, Union
import torch import torch
from torch.fx import Node, GraphModule from torch.fx import GraphModule, Node
from typing import Union, Dict, List, Tuple
from . import META_COMPATIBILITY from .._compatibility import compatibility, is_compatible_with_meta
__all__ = [ __all__ = [
'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out" 'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
] ]
@compatibility(is_backward_compatible=True)
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int: def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node. """Calculate activation size of a node.
...@@ -29,6 +32,7 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int: ...@@ -29,6 +32,7 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
return act_size return act_size
@compatibility(is_backward_compatible=True)
def parameter_size(mod: torch.nn.Module) -> int: def parameter_size(mod: torch.nn.Module) -> int:
"""Calculate parameter size of a node. """Calculate parameter size of a node.
...@@ -111,8 +115,8 @@ def is_inplace(n: Node): ...@@ -111,8 +115,8 @@ def is_inplace(n: Node):
inplace = False inplace = False
if n.op == "call_function": if n.op == "call_function":
inplace = n.kwargs.get("inplace", False) inplace = n.kwargs.get("inplace", False)
if META_COMPATIBILITY: if is_compatible_with_meta():
from .constant import ALIAS_ATEN from .constants import ALIAS_ATEN
if n.target in ALIAS_ATEN: if n.target in ALIAS_ATEN:
inplace = True inplace = True
elif n.op == "call_module": elif n.op == "call_module":
......
# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py # adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
# ideas from https://pastebin.com/AkvAyJBw # ideas from https://pastebin.com/AkvAyJBw
from functools import partial, reduce
import operator import operator
from typing import Callable, List, Any from functools import partial, reduce
from numbers import Number from numbers import Number
from typing import Any, Callable, List
import torch import torch
aten = torch.ops.aten aten = torch.ops.aten
......
import time
from functools import partial from functools import partial
from typing import Callable, Any, Dict, Tuple from typing import Any, Callable, Dict, Tuple
import torch import torch
from torch.nn.parameter import Parameter
from torch.fx import Graph, Node from torch.fx import Graph, Node
from torch.fx.node import Argument, Target from torch.fx.node import Argument, Target
from torch.nn.parameter import Parameter
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo
from .._compatibility import compatibility
from .constants import ALIAS_ATEN
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
from .memory import activation_size, parameter_size from .memory import activation_size, parameter_size
from .constant import ALIAS_ATEN
from .tensor import MetaTensor
from .opcount import flop_mapping from .opcount import flop_mapping
import time from .tensor import MetaTensor
__all__ = ['profile_function', 'profile_module', 'profile_method'] __all__ = ['profile_function', 'profile_module', 'profile_method']
...@@ -41,6 +44,7 @@ def detach_variables(x): ...@@ -41,6 +44,7 @@ def detach_variables(x):
return x return x
@compatibility(is_backward_compatible=True)
def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]: def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
"""Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30 """Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30
To profile the actual forward memory, we first run target in the context torch.no_grad() to get To profile the actual forward memory, we first run target in the context torch.no_grad() to get
...@@ -140,6 +144,7 @@ def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ... ...@@ -140,6 +144,7 @@ def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...
return tree_map(detach_variables, out), graphinfo return tree_map(detach_variables, out), graphinfo
@compatibility(is_backward_compatible=False)
def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]: def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
""" """
Profile a Callable function with args and kwargs on meta devices. Profile a Callable function with args and kwargs on meta devices.
...@@ -277,6 +282,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G ...@@ -277,6 +282,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
return tree_map(unwrap, out), graph_info return tree_map(unwrap, out), graph_info
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target', device: str = 'meta') -> Callable: def profile_function(target: 'Target', device: str = 'meta') -> Callable:
""" """
Wrap a `call_function` node or `torch.nn.functional` in order to Wrap a `call_function` node or `torch.nn.functional` in order to
...@@ -335,6 +341,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: ...@@ -335,6 +341,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
return f return f
@compatibility(is_backward_compatible=True)
def profile_method(target: 'Target', device: str = 'meta') -> Callable: def profile_method(target: 'Target', device: str = 'meta') -> Callable:
""" """
Wrap a `call_method` node Wrap a `call_method` node
...@@ -353,6 +360,7 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable: ...@@ -353,6 +360,7 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
return f return f
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
""" """
Wrap a `call_module` node or `torch.nn` in order to Wrap a `call_module` node or `torch.nn` in order to
......
import uuid
from copy import deepcopy from copy import deepcopy
from typing import Optional from typing import Optional
import torch import torch
from torch.utils._pytree import tree_map, tree_flatten from torch.types import _bool, _device, _dtype
from torch.types import _bool, _dtype, _device from torch.utils._pytree import tree_flatten, tree_map
import uuid
from .constant import ALIAS_ATEN from .._compatibility import compatibility
from .constants import ALIAS_ATEN
__all__ = ['MetaTensor'] __all__ = ['MetaTensor']
...@@ -15,6 +18,7 @@ def set_uuid(x): ...@@ -15,6 +18,7 @@ def set_uuid(x):
setattr(x, 'uuid', uuid.uuid4()) setattr(x, 'uuid', uuid.uuid4())
@compatibility(is_backward_compatible=False)
class MetaTensor(torch.Tensor): class MetaTensor(torch.Tensor):
""" """
A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops. A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment