Unverified Commit 4f596932 authored by Super Daniel's avatar Super Daniel Committed by GitHub
Browse files

[fx] provide a stable but not accurate enough version of profiler. (#1547)

* [fx] compute memory stat and flop count for MetaInfoProp.

* [fx] modify node attribute.

* [fx] modify ckpt_chen.

* [fx] fix compatibility.

* [fx] fix import error.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip if torch 1.11.0.

* [fx] recover MetaInfoProp support for PyTorch 1.11.

* [fx] provide a stable but not accurate enough version of profiler.

* [fx] provide a stable but not accurate enough version of profiler.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix import error.
parent 7d49e7b2
try:
from ._meta_registrations import *
from . import _meta_registrations
META_COMPATIBILITY = True
except:
import torch
META_COMPATIBILITY = False
print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch,
get_default_parser)
......
......@@ -181,6 +181,12 @@ def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
return grad_in
@register_meta(aten.hardtanh_backward.default)
def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val: int, max_val: int):
grad_in = torch.empty_like(input)
return grad_in
@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
return torch.empty_like(input)
......@@ -321,3 +327,17 @@ def meta_index_Tensor(self, indices):
else:
replacement_shape = list(index.shape)
return self.new_empty(before_shape + replacement_shape + after_shape)
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
scale_grad_by_freq):
return torch.empty((num_weights, grad_output.size(-1)),
dtype=grad_output.dtype,
device=grad_output.device,
layout=grad_output.layout)
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
return torch.empty_like(condition)
......@@ -73,10 +73,10 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
y = 0
prev_idx = 2
for (idx, n) in enumerate(gm.graph.nodes):
temp += getattr(n, '__activation__')
temp += getattr(n, 'fwd_out')
y = max(y, temp)
if temp > b and n in ckpt_nodes:
x += getattr(n, '__activation__')
x += getattr(n, 'fwd_out')
temp = 0
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1
......
from operator import add, getitem
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_map
from typing import Any, Tuple, NamedTuple, Optional, Dict
from functools import reduce
from typing import Any, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
from torch.fx.immutable_collections import immutable_dict, immutable_list
from colossalai.fx.profiler import MetaProfile, MetaTensor, profile_function, profile_module, calculate_activation_size, profile_method
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size
@compatibility(is_backward_compatible=True)
......@@ -71,14 +68,6 @@ class MetaInfoProp(torch.fx.Interpreter):
"""
@compatibility(is_backward_compatible=True)
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
"""
Add additional check for initial args to ensure all the tensor appears with `device='meta'`
"""
args = tree_map(lambda elem: MetaTensor(elem.to('meta')) if isinstance(elem, torch.Tensor) else elem, args)
return super().run(*args, initial_env, enable_io_processing)
@compatibility(is_backward_compatible=True)
def run_node(self, n: Node) -> Any:
"""
......@@ -93,8 +82,7 @@ class MetaInfoProp(torch.fx.Interpreter):
Returns:
Any: The result of executing ``n``
"""
result, profile = super().run_node(n)
profile: MetaProfile
result, flop_count, mem_stat = super().run_node(n)
def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
......@@ -106,12 +94,17 @@ class MetaInfoProp(torch.fx.Interpreter):
n.meta['tensor_meta'] = meta
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', profile.param + profile.activation)
setattr(n, '__param__', profile.param)
setattr(n, '__activation__', profile.activation)
setattr(n, '__flops__', profile.flops)
setattr(n, '__macs__', profile.macs)
setattr(n, 'node_size', mem_stat[1])
setattr(n, 'fwd_flop', flop_count[0])
setattr(n, 'bwd_flop', flop_count[1])
setattr(n, 'fwd_tmp', mem_stat[0])
setattr(n, 'fwd_out', mem_stat[1])
setattr(n, 'bwd_tmp', mem_stat[2])
setattr(n, 'bwd_out', mem_stat[3])
n.meta['type'] = type(result)
for param in self.module.parameters():
param.grad = None
return result
# Main Node running APIs
......@@ -132,11 +125,12 @@ class MetaInfoProp(torch.fx.Interpreter):
Returns:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
result = super().placeholder(target, args, kwargs)
# A placeholder node only has activation
return result, MetaProfile(0, calculate_activation_size(result), 0, 0)
return result, (0, 0), (0, activation_size(result), 0, 0)
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
......@@ -153,10 +147,10 @@ class MetaInfoProp(torch.fx.Interpreter):
Return:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
# A get_attr node never has parameters, activations, FLOPs, or MACs
return super().get_attr(target, args, kwargs), MetaProfile(0, 0, 0, 0)
return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0)
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
......@@ -172,7 +166,8 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
assert not isinstance(target, str)
return profile_function(target)(*args, **kwargs)
......@@ -191,7 +186,8 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
return profile_method(target)(*args, **kwargs)
......@@ -209,7 +205,8 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
# Retrieve executed args and kwargs values from the environment
# Execute the method and return the result
......@@ -231,9 +228,11 @@ class MetaInfoProp(torch.fx.Interpreter):
kwargs (Dict): Dict of keyword arguments for this invocation
Return:
Any: The return value referenced by the output node
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
return args[0], MetaProfile(0, 0, 0, 0)
return args[0], (0, 0), (0, 0, 0, 0)
def propagate(self, *args):
"""
......
from .meta_tensor import MetaTensor
from .registry import meta_profiler_function, meta_profiler_module
from .profiler_function import *
from .profiler_module import *
from .profiler import *
from ... import META_COMPATIBILITY
if META_COMPATIBILITY:
from .opcount import flop_mapping
from .tensor import MetaTensor
from .profiler import profile_function, profile_method, profile_module, _profile
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module
from .memory import parameter_size, activation_size
from .registry import meta_profiler_function, meta_profiler_module
from .profiler_function import *
from .profiler_module import *
from .profiler import profile_function, profile_method, profile_module
from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx.node import Argument, Target
from . import meta_profiler_function, meta_profiler_module
from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
__all__ = ['profile_function', 'profile_module', 'profile_method']
CALL_FUNCTION_MSG = \
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_function
@meta_profiler_function.register(YOUR_FUNCTION)
def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""
CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
CALL_MODULE_MSG = \
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_module
@meta_profiler_module.register(YOUR_MODULE)
def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""
def profile_function(target: 'Target') -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
Unfortunately, backward memory cost and FLOPs are estimated results.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn.functional` are available.
Examples:
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
>>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
target.__name__), CALL_FUNCTION_MSG.format(target)
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
if target not in INPLACE_OPS and not kwargs.get('inplace', False):
fwd_out = activation_size(out)
if meta_profiler_function.has(target):
profiler = meta_profiler_function.get(target)
else:
profiler = meta_profiler_function.get(target.__name__)
fwd_flop, _ = profiler(*args, **kwargs)
return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
f.__name__ = target.__name__
func = target
return f
def profile_method(target: 'Target') -> Callable:
"""
Wrap a `call_method` node
record the memory cost and FLOPs of the execution.
Warnings:
This is not fully implemented and you may follow the error message to debug.
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
out = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
target, INPLACE_METHOD, NON_INPLACE_METHOD)
# call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
return f
def profile_module(module: torch.nn.Module) -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn` are available.
Example:
>>> input = torch.rand(4, 3, 224, 224, device='meta')
>>> mod = torch.nn.Conv2d(3, 128, 3)
>>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
if getattr(module, 'inplace', False):
fwd_out = activation_size(out)
profiler = meta_profiler_module.get(type(module))
fwd_flop, _ = profiler(module, *args, **kwargs)
return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
f.__name__ = module.__class__.__name__
func = module.forward
return f
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment