Commit 08f2920e authored by zhuwenwen

init colossalai, support dtk2304

parent da3f0934
import time
from functools import partial
from typing import Any, Callable, Dict, Tuple
import torch
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.nn.parameter import Parameter
from torch.utils._pytree import tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
from .memory_utils import activation_size, parameter_size
from .opcount import flop_mapping
from .tensor import MetaTensor
__all__ = ['profile_function', 'profile_module', 'profile_method']
# super-dainiu: this cache should be global, otherwise it cannot
# track duplicated tensors between nodes
cache = set()
# a global identifier for inplace ops
do_not_cache = False
def normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
def is_autogradable(x):
return isinstance(x, torch.Tensor) and x.is_floating_point()
def detach_variables(x):
if isinstance(x, torch.Tensor):
requires_grad = x.requires_grad
x = x.detach()
x.requires_grad = requires_grad
return x
@compatibility(is_backward_compatible=True)
def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
"""Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30
To profile the actual forward memory, we first run target in the context torch.no_grad() to get
the fwd_mem_out, then we run target with grad enable to found the extra memory stored in the memory
by memory allocated minus the fwd_mem_out.
To profile the actual backward memory, we first make dummy gradient for torch.autograd.backward, then
find the bwd_mem_tmp with memory peak during the process minus bwd_mem_out(it is actually equal to size
of args and kwargs).
We also add time stamps to profile the real forward and backward time.
Args:
target (Callable): A Callable function
args (Any): Arguments
kwargs (Any): Arguments
Returns:
Tuple[Tuple[Any, ...], GraphInfo]: Output for next node & memory cost and real forward and backward
time.
"""
graphinfo = GraphInfo()
# detach input from the graph
args = tree_map(detach_variables, args)
kwargs = tree_map(detach_variables, kwargs)
if isinstance(target, str):
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
# calculate fwd_mem_out
mem_stamp0 = torch.cuda.memory_allocated()
with torch.no_grad():
out = getattr(self_obj, target)(*args_tail, **kwargs)
mem_stamp1 = torch.cuda.memory_allocated()
graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0
del out
# calculate fwd_mem_tmp & fwd_time
mem_stamp0 = torch.cuda.memory_allocated()
fwd_time0 = time.time()
out = getattr(self_obj, target)(*args_tail, **kwargs)
fwd_time1 = time.time()
graphinfo.fwd_time = fwd_time1 - fwd_time0
mem_stamp1 = torch.cuda.memory_allocated()
graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out
# calculate bwd_mem_tmp & bwd_time
grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out)
torch.cuda.reset_peak_memory_stats()
mem_stamp0 = torch.cuda.memory_allocated()
bwd_time0 = time.time()
torch.autograd.backward(out, grad_tensors=grad_tensors)
bwd_time1 = time.time()
graphinfo.bwd_time = bwd_time1 - bwd_time0
mem_stamp1 = torch.cuda.max_memory_allocated()
# calculate bwd memory stats
# NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation
graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs)
graphinfo.bwd_mem_out += parameter_size(self_obj) if hasattr(self_obj, "parameters") else 0
graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out
else:
# calculate fwd_mem_out
mem_stamp0 = torch.cuda.memory_allocated()
with torch.no_grad():
out = target(*args, **kwargs)
mem_stamp1 = torch.cuda.memory_allocated()
graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0
del out
# calculate fwd_mem_tmp & fwd_time
mem_stamp0 = torch.cuda.memory_allocated()
fwd_time0 = time.time()
out = target(*args, **kwargs)
fwd_time1 = time.time()
graphinfo.fwd_time = fwd_time1 - fwd_time0
mem_stamp1 = torch.cuda.memory_allocated()
graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out
# calculate bwd_mem_tmp & bwd_time
grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out)
torch.cuda.reset_peak_memory_stats()
mem_stamp0 = torch.cuda.memory_allocated()
bwd_time0 = time.time()
torch.autograd.backward(out, grad_tensors=grad_tensors)
bwd_time1 = time.time()
graphinfo.bwd_time = bwd_time1 - bwd_time0
mem_stamp1 = torch.cuda.max_memory_allocated()
# calculate bwd memory stats
# NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation
graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs)
graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target, "__self__") and hasattr(target.__self__, "parameters") else 0
graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out
return tree_map(detach_variables, out), graphinfo
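# A minimal usage sketch of `_profile_concrete` (assumptions for illustration only: a CUDA
# device is available, since the function reads `torch.cuda.memory_allocated()`, and the
# layer/input names below are hypothetical):
# >>> layer = torch.nn.Linear(1024, 1024).cuda()
# >>> x = torch.rand(64, 1024, device='cuda', requires_grad=True)
# >>> out, info = _profile_concrete(layer.forward, x)
# >>> print(info.fwd_time, info.bwd_time, info.fwd_mem_tmp, info.bwd_mem_tmp)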
@compatibility(is_backward_compatible=False)
def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
"""
Profile a Callable function with args and kwargs on meta devices.
Args:
target (Callable): A Callable function
args (Any): Argument
kwargs (Any): Argument
Returns:
out (Tuple[Any, ...]): The output of the profiled callable.
meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
# This subgraph traces aten level ops inside one node.
subgraph = Graph()
# `flop_count` serves as a global dictionary to store results.
flop_count = {
Phase.FORWARD: 0,
Phase.BACKWARD: 0,
}
# FlopTensor not only collects the FLOP statistics of a single node,
# it also builds a full autograd graph for this node.
# This makes sure we can analyze the memory dependencies and
# decide which forward intermediate results should be kept until
# backward is executed.
# Hopefully, this approach will provide a better estimation of memory.
class FlopTensor(MetaTensor):
_node: Node = None
def __repr__(self):
if self.grad_fn:
return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, grad_fn={self.grad_fn})"
return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, requires_grad={self.requires_grad})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
node = subgraph.create_node('call_function', func, args_node, kwargs_node)
out = super().__torch_dispatch__(func, types, args, kwargs)
flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
node.meta['phase'] = phase
# super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
# i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
# `Phase.FORWARD`
if phase == Phase.FORWARD:
if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
node.meta['phase'] = Phase.PLACEHOLDER
# TODO(yby): specify `saved_tensors` for backward memory estimation
node.meta['saved_tensor'] = []
if phase == Phase.BACKWARD:
node.meta['saved_tensor'] = normalize_tuple(out)
def wrap(x):
if isinstance(x, MetaTensor):
x = FlopTensor(x)
x._node = node
return x
out = tree_map(wrap, out)
return out
def wrap(x):
if isinstance(x, torch.Tensor):
x = FlopTensor(x)
if is_autogradable(x):
x.requires_grad_(True)
x._node = subgraph.create_node('placeholder',
'placeholder', (subgraph._root,),
name=subgraph._graph_namespace.create_name('input', x._tensor))
x._node.meta['phase'] = Phase.PLACEHOLDER
x._node.meta['saved_tensor'] = []
return x
# Basically, we need to detach the args and kwargs from the outer graph.
args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)
def pack(x):
global cache, do_not_cache
if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
tensor = x._tensor.detach()
tensor.data_ptr = x._tensor.data_ptr
x._node.meta['saved_tensor'] += [tensor]
if not do_not_cache:
cache.add(x._tensor.data_ptr())
return x
def unpack(x):
return x
# `phase` will mark the phase of autograd from outside scope.
phase = Phase.FORWARD
# mark saved tensors with saved_tensors_hooks
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
if isinstance(target, str):
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
out = getattr(self_obj, target)(*args_tail, **kwargs)
else:
out = target(*args, **kwargs)
# If the output is not a floating point `torch.Tensor`, or it does not
# require grad, then we should not run backward for this node.
if all(map(lambda x: is_autogradable(x) and x.requires_grad, normalize_tuple(out))):
grad_out = [torch.zeros_like(t) for t in normalize_tuple(out)]
phase = Phase.BACKWARD
torch.autograd.backward(
out,
grad_out,
)
graph_info = autograd_graph_analysis(subgraph)
graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
def extract_tensor(x: Any):
if isinstance(x, MetaTensor):
tensor = x._tensor.detach()
tensor.data_ptr = x._tensor.data_ptr
return tensor
if not isinstance(x, torch.finfo):
return x
graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))
def unwrap(x):
return MetaTensor(x) if isinstance(x, torch.Tensor) else x
return tree_map(unwrap, out), graph_info
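# A minimal usage sketch of `_profile_meta` (assumption for illustration: the input is
# wrapped in `MetaTensor` so every aten call can be intercepted; names are hypothetical):
# >>> layer = torch.nn.Linear(1024, 1024)
# >>> x = MetaTensor(torch.rand(64, 1024, requires_grad=True), fake_device='cuda:0')
# >>> out, info = _profile_meta(layer.forward, x)
# >>> print(info.fwd_flop, info.bwd_flop)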
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target', device: str = 'meta') -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only the original `torch.nn.functional` functions are available.
Examples:
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
>>> output, meta_info = profile_function(func)(input)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# find the grad for parameter in args and kwargs
param_size = 0
def get_param_size(x):
nonlocal param_size
if isinstance(x, Parameter):
param_size += activation_size(x)
tree_map(get_param_size, args)
tree_map(get_param_size, kwargs)
# If this `call_function` has an `inplace` argument, we should
# still run the profiling but discard some results regarding `target`
global do_not_cache
inplace = kwargs.get('inplace', False)
if target in OUTPUT_SAVED_OPS:
do_not_cache = True
if inplace:
do_not_cache = True
kwargs['inplace'] = False
if device == 'meta':
out, meta = _profile_meta(func, *args, **kwargs)
else:
out, meta = _profile_concrete(func, *args, **kwargs)
if inplace:
kwargs['inplace'] = True
meta.bwd_mem_tmp = 0
meta.bwd_mem_out = 0
do_not_cache = False
meta.bwd_mem_out -= param_size
return out, meta
f.__name__ = target.__name__
func = target
return f
@compatibility(is_backward_compatible=True)
def profile_method(target: 'Target', device: str = 'meta') -> Callable:
"""
Wrap a `call_method` node in order to
record the memory cost and FLOPs of the execution.
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
if device == 'meta':
out, meta = _profile_meta(target, *args, **kwargs)
else:
out, meta = _profile_concrete(target, *args, **kwargs)
return out, meta
return f
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only the original `torch.nn` modules are available.
Example:
>>> input = torch.rand(4, 3, 224, 224, device='meta')
>>> mod = torch.nn.Conv2d(3, 128, 3)
>>> output, meta_info = profile_module(mod)(input)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# calculate parameter size
param_size = parameter_size(module)
# If this `call_module` has an `inplace` argument, we should
# still run the profiling but discard some results regarding `module`.
global do_not_cache
inplace = getattr(module, 'inplace', False)
if type(module) in OUTPUT_SAVED_MOD:
do_not_cache = True
if inplace:
do_not_cache = True
module.inplace = False
if device == 'meta':
out, meta = _profile_meta(func, *args, **kwargs)
else:
out, meta = _profile_concrete(func, *args, **kwargs)
if inplace:
module.inplace = True
meta.bwd_mem_tmp = 0
meta.bwd_mem_out = 0
do_not_cache = False
# grad for param will not be counted
meta.bwd_mem_out -= param_size
return out, meta
f.__name__ = module.__class__.__name__
func = module.forward
return f
import torch
from torch.fx import Node
from .._compatibility import compatibility, is_compatible_with_meta
from .memory_utils import activation_size
if is_compatible_with_meta():
from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
@compatibility(is_backward_compatible=False)
def calculate_fwd_in(n: Node) -> int:
"""A helper function to calculate `fwd_in` (with sharding spec)
Args:
n (Node): a node from the graph
Returns:
fwd_in (int): the result of `fwd_in`
"""
# TODO(super-dainiu): should divide the memory by sharding spec
return activation_size(n.meta["fwd_in"])
@compatibility(is_backward_compatible=False)
def calculate_fwd_tmp(n: Node) -> int:
"""A helper function to calculate `fwd_tmp` (with sharding spec)
Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.
Args:
n (Node): a node from the graph
Returns:
fwd_tmp (int): the result of `fwd_tmp`
"""
# TODO(super-dainiu): should divide the memory by sharding spec
def is_relu_like_node(n: Node) -> bool:
"""Check if a node is a ReLU-like node.
ReLU-like nodes have the following properties:
- They are either `call_function` or `call_module`
- Their output tensors are directly saved for backward
- Their input tensors are not saved for backward
An example is `torch.nn.functional.softmax` which has (forward + backward):
def forward(self, input_2):
_softmax_default = torch.ops.aten._softmax.default(input_2, None, None); input_2 = None
zeros_like_default = torch.ops.aten.zeros_like.default(_softmax_default, dtype = None, layout = None, device = None, pin_memory = None)
detach_default = torch.ops.aten.detach.default(_softmax_default); _softmax_default = None
_softmax_backward_data_default = torch.ops.aten._softmax_backward_data.default(zeros_like_default, detach_default, None, None); zeros_like_default = detach_default = None
detach_default_1 = torch.ops.aten.detach.default(_softmax_backward_data_default); _softmax_backward_data_default = None
detach_default_2 = torch.ops.aten.detach.default(detach_default_1); detach_default_1 = None
Args:
n (Node): A node from the graph
Returns:
bool: Whether the node is a ReLU-like node
"""
if n.op == 'call_function':
return n.target in OUTPUT_SAVED_OPS
elif n.op == 'call_module':
return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
return False
if not is_relu_like_node(n):
return activation_size(n.meta["fwd_tmp"])
return 0
@compatibility(is_backward_compatible=False)
def calculate_fwd_out(n: Node) -> int:
"""A helper function to calculate `fwd_out` (with sharding spec)
Args:
n (Node): a node from the graph
Returns:
fwd_out (int): the result of `fwd_out`
"""
# TODO(super-dainiu): should divide the memory by sharding spec
def intersect(a, b):
return {k: a[k] for k in a if k in b}
fwd_in = dict()
for u in n.users:
fwd_in.update({x.data_ptr(): x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor)})
fwd_out = {x.data_ptr(): x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor)}
return activation_size(intersect(fwd_in, fwd_out))
def calculate_fwd_time(n: Node) -> float:
"""A helper function to calculate `fwd_time` (with sharding spec)
Args:
n (Node): a node from the graph
Returns:
fwd_time (float): the result of `fwd_time`
"""
# TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
return n.meta["fwd_flop"]
def calculate_bwd_time(n: Node) -> float:
"""A helper function to calculate `bwd_time` (with sharding spec)
Args:
n (Node): a node from the graph
Returns:
bwd_time (float): the result of `bwd_time`
"""
# TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
return n.meta["bwd_flop"]
import uuid
from copy import deepcopy
from typing import Optional
import torch
from torch.types import _bool, _device, _dtype
from torch.utils._pytree import tree_flatten, tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN
__all__ = ['MetaTensor']
def set_data_ptr(x):
if isinstance(x, torch.Tensor):
if not x.data_ptr():
data_ptr = uuid.uuid4()
x.data_ptr = lambda: data_ptr
@compatibility(is_backward_compatible=False)
class MetaTensor(torch.Tensor):
"""
A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
`fake_device` is the device that `MetaTensor` is supposed to run on.
"""
_tensor: torch.Tensor
__slots__ = ['_tensor']
@staticmethod
def __new__(cls, elem, fake_device=None):
# Avoid multiple wrapping
if isinstance(elem, MetaTensor):
fake_device = elem.device if fake_device is None else fake_device
elem = elem._tensor
# The wrapping tensor (MetaTensor) shouldn't hold any
# memory for the class in question, but it should still
# advertise the same device as before
r = torch.Tensor._make_wrapper_subclass(
cls,
elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=fake_device if fake_device is not None else elem.device,
requires_grad=elem.requires_grad) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
r._tensor = r._tensor.to(torch.device('meta'))
# only tensor not on `meta` should be copied to `meta`
set_data_ptr(r._tensor)
return r
def __repr__(self):
if self.grad_fn:
return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
return f"MetaTensor({self._tensor}, fake_device='{self.device}')"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
fake_device = None
def unwrap(x):
nonlocal fake_device
if isinstance(x, MetaTensor):
fake_device = x.device
x = x._tensor
elif isinstance(x, torch.Tensor):
fake_device = x.device
x = x.to(torch.device('meta'))
return x
if 'device' in kwargs:
fake_device = kwargs['device']
kwargs['device'] = torch.device('meta')
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
# here we keep the uuid of the input because ops in ALIAS_ATEN do not create a physical copy
# of the input
if func in ALIAS_ATEN:
out.data_ptr = args[0].data_ptr
# Now, we want to continue propagating this tensor, so we rewrap Tensors in
# our custom tensor subclass
def wrap(x):
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
x = x.to(torch.device('meta'))
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
return tree_map(wrap, out)
def to(self, *args, **kwargs) -> torch.Tensor:
"""An extension of `torch.Tensor.to()` to MetaTensor
Returns:
result (MetaTensor): MetaTensor
Usage:
>>> tensor = MetaTensor(torch.rand(10), fake_device='cuda:100')
>>> tensor.to(torch.uint8)
MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), fake_device='cuda:100')
>>> tensor.to(torch.device('cuda:42'))
MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='cuda:42')
>>> tensor.to('vulkan')
MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
"""
# this imitates c++ function in the way of @overload
device = None
for arg in args:
if isinstance(arg, str) or isinstance(arg, _device):
device = arg
if 'device' in kwargs:
device = kwargs['device']
result = super().to(*args, **kwargs)
if device is not None:
result = MetaTensor(result, fake_device=device)
return result
def cpu(self, *args, **kwargs):
if self.device.type == 'cpu':
return self.to(*args, **kwargs)
return self.to(*args, device='cpu', **kwargs)
def cuda(self, *args, **kwargs):
if self.device.type == 'cuda':
return self.to(*args, **kwargs)
return self.to(*args, device='cuda', **kwargs)
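# A minimal usage sketch of `MetaTensor` (assumption for illustration: the op we call is
# supported by the meta backend, which holds for standard aten matmul):
# >>> a = MetaTensor(torch.rand(4, 8), fake_device='cuda:0')
# >>> b = MetaTensor(torch.rand(8, 2), fake_device='cuda:0')
# >>> (a @ b).shape    # computed on meta storage, no real device memory is allocated
# torch.Size([4, 2])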
import operator
import torch
from torch.fx.proxy import Proxy, Attribute
from typing import List, Union, Any
from colossalai.fx.tracer.meta_patch import meta_patched_function
__all__ = ['ColoProxy']
class ColoProxy(Proxy):
"""
ColoProxy is a proxy class which uses meta tensors to handle data-dependent control flow. The original torch.fx proxy
cannot be used to evaluate a condition statement; with this proxy, torch.fx can still trace code that contains if statements.
Example::
proxy = tracer.create_proxy(...)
proxy.meta_data = torch.empty(4, 2, device='meta')
print(len(proxy)) # expect output 4
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.node._meta_data = None
@property
def meta_data(self):
return self.node._meta_data
@meta_data.setter
def meta_data(self, data: Any):
self.node._meta_data = data
@property
def has_meta_data(self):
return self.meta_data is not None
def _assert_meta_data_is_tensor(self):
assert torch.is_tensor(
self.meta_data) and self.meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}'
def _assert_has_meta_data(self):
assert self.meta_data is not None, f'Meta data is not set for {self.node.name}'
def __len__(self):
self._assert_has_meta_data()
return len(self.meta_data)
def __int__(self):
self._assert_has_meta_data()
return int(self.meta_data)
def __float__(self):
self._assert_has_meta_data()
return float(self.meta_data)
def __bool__(self):
self._assert_has_meta_data()
return self.meta_data
def __getattr__(self, k):
return ColoAttribute(self, k)
def __contains__(self, key):
if self.node.op == "placeholder":
# this is used to handle like
# if x in kwargs
# we don't handle this case for now
return False
return super().__contains__(key)
def extract_meta(*args, **kwargs):
"""
This function is copied from _tracer_utils.py to avoid a circular import issue.
"""
def _convert(val):
if isinstance(val, ColoProxy):
return val.meta_data
elif isinstance(val, (list, tuple)):
return type(val)([_convert(ele) for ele in val])
return val
new_args = [_convert(val) for val in args]
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
return new_args, new_kwargs
class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str):
self.root = root
self.attr = attr
self.tracer = root.tracer
self._node = None
@property
def node(self):
if self._node is None:
proxy = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {})
if not isinstance(proxy, ColoProxy):
meta_args, meta_kwargs = extract_meta(*(self.root, self.attr))
meta_out = getattr(*meta_args, **meta_kwargs)
proxy = ColoProxy(proxy.node)
proxy.meta_data = meta_out
self._node = proxy.node
return self._node
def __call__(self, *args, **kwargs):
proxy = self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
if not isinstance(proxy, ColoProxy):
meta_args, meta_kwargs = extract_meta(*((self.root,) + args), **kwargs)
method = getattr(meta_args[0].__class__, self.attr)
if meta_patched_function.has(method):
meta_target = meta_patched_function.get(method)
elif meta_patched_function.has(method.__name__):
meta_target = meta_patched_function.get(method.__name__)
else:
meta_target = method
meta_out = meta_target(*meta_args, **meta_kwargs)
proxy = ColoProxy(proxy.node)
proxy.meta_data = meta_out
return proxy
from colossalai.fx.tracer.meta_patch.patched_function.python_ops import operator_getitem
from ._meta_trace import meta_trace
from ._symbolic_trace import symbolic_trace
from .tracer import ColoTracer
import torch
from torch.fx import Graph, Node
from torch.utils._pytree import tree_map
def normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
def is_autogradable(x):
return isinstance(x, torch.Tensor) and x.is_floating_point()
def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
"""Trace forward and backward graph with MetaTensor
Args:
module (torch.nn.Module): The target module for tracing.
Returns:
graph (torch.fx.Graph): The computation graph.
Usage:
>>> import torchvision.models as tm
>>> model = tm.alexnet()
>>> graph = meta_trace(model, torch.rand(1000, 3, 224, 224))
>>> graph.print_tabular()
"""
graph = Graph()
namespace = graph._graph_namespace
class MetaProxy(torch.Tensor):
"""
A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
"""
_tensor: torch.Tensor
_node: Node
__slots__ = ['_tensor', '_node']
@staticmethod
def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
r = torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
strides=tensor.stride(),
storage_offset=tensor.storage_offset(),
dtype=tensor.dtype,
layout=tensor.layout,
device=fake_device if fake_device is not None else tensor.device,
requires_grad=tensor.requires_grad) # deceive the frontend for aten selections
r._tensor = tensor
if placeholder:
if name is None:
name = 'input'
r._node = graph.create_node('placeholder',
'placeholder', (graph._root,),
name=namespace.create_name(name, tensor))
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
r._tensor = r._tensor.to(torch.device('meta'))
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(x):
nonlocal fake_device
if isinstance(x, MetaProxy):
fake_device = x.device
x = x._tensor
# assert not isinstance(x, MetaProxy)
elif isinstance(x, torch.Tensor):
fake_device = x.device
x = x.to(torch.device('meta'))
return x
def get_node(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
x = MetaProxy(x, placeholder=True, name='weight')
return x if not hasattr(x, '_node') else x._node
args_node = tree_map(get_node, args)
kwargs_node = tree_map(get_node, kwargs)
node = graph.create_node('call_function', func, args_node, kwargs_node)
if 'device' in kwargs:
fake_device = kwargs['device']
kwargs['device'] = torch.device('meta')
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
# Now, we want to continue propagating this tensor, so we rewrap Tensors in
# our custom tensor subclass
def wrap(x):
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
x = x.to(torch.device('meta'))
return MetaProxy(
x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
def set_node(x):
x._node = node
out = tree_map(wrap, out)
tree_map(set_node, out)
return out
def wrap(x):
return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x
args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)
out = module(*args, **kwargs)
for tensor in normalize_tuple(out):
if is_autogradable(tensor) and tensor.requires_grad:
grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
torch.autograd.backward(tensor,
MetaProxy(grad, fake_device=tensor.device, placeholder=True),
retain_graph=True)
return graph
from typing import Any, Callable, Dict, Optional, Union
import torch
from colossalai.fx import ColoGraphModule
from colossalai.fx._compatibility import compatibility
from .tracer import ColoTracer
@compatibility(is_backward_compatible=True)
def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None,
) -> ColoGraphModule:
"""
Symbolic tracing API
Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule``
constructed by recording operations seen while tracing through ``root``.
With ``meta_args``, we can trace models that are otherwise untraceable due to control flow. If the inputs are
specified with ``meta_args`` only, the tracing can be done ahead of time.
Note that ``meta_args`` is a dict that maps argument names to argument values.
Usage:
>>> model = ...
# if this works
>>> gm = symbolic_trace(model, concrete_args=concrete_args)
# else try this
>>> gm = symbolic_trace(model, concrete_args=concrete_args, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})
Args:
root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted
into a Graph representation.
concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be used for tracing.
meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``.
Defaults to None.
Returns:
ColoGraphModule: A ``ColoGraphModule`` created from the recorded operations from ``root``.
Warnings:
This API is still under development and may contain bugs. Feel free to report any bugs to the Colossal-AI team.
"""
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
from typing import List, Union, Any
from ..proxy import ColoProxy, ColoAttribute
import torch
from .meta_patch import meta_patched_function, meta_patched_module
__all__ = ['is_element_in_list', 'extract_meta']
def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
if isinstance(elements, (tuple, list, set)):
for ele in elements:
if ele not in list_:
return False, ele
else:
if elements not in list_:
return False, elements
return True, None
def extract_meta(*args, **kwargs):
def _convert(val):
if isinstance(val, ColoProxy):
return val.meta_data
elif isinstance(val, (list, tuple)):
return type(val)([_convert(ele) for ele in val])
return val
new_args = [_convert(val) for val in args]
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
return new_args, new_kwargs
def compute_meta_data_for_functions_proxy(target, args, kwargs):
args_metas, kwargs_metas = extract_meta(*args, **kwargs)
# fetch patched function
if meta_patched_function.has(target):
meta_target = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
meta_target = meta_patched_function.get(target.__name__)
else:
meta_target = target
meta_out = meta_target(*args_metas, **kwargs_metas)
if isinstance(meta_out, torch.Tensor):
meta_out = meta_out.to(device="meta")
return meta_out
from .patched_bias_addition_function import *
from .patched_bias_addition_module import *
from .addbmm import Addbmm
from .addmm import Addmm
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict
from .linear import Linear
import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addbmm)
@bias_addition_function.register(torch.addbmm)
class Addbmm(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self):
kwargs = {}
if 'beta' in self.kwargs:
kwargs['beta'] = self.kwargs['beta']
if 'alpha' in self.kwargs:
kwargs['alpha'] = self.kwargs['alpha']
return kwargs
def create_non_bias_func_proxy(self, input_proxy, other_proxy):
"""
This method is used to create the non_bias_func proxy; the node created by this proxy will
perform the main computation (torch.bmm here) without the bias term.
"""
assert self.substitute_func == torch.bmm
node_kind = 'call_function'
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
# torch.bmm does not have any kwargs
node_kwargs = {}
non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return non_bias_func_proxy
def insert_sum_node(self, input_proxy, sum_dims=0):
'''
This method is used to sum the input_proxy along sum_dims.
'''
node_kind = 'call_function'
node_target = torch.sum
node_args = (input_proxy, sum_dims)
node_kwargs = {}
sum_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return sum_proxy
def generate(self):
# The formula for addbmm is output = beta * input + alpha * (torch.bmm(b1, b2).sum(dim=0))
# doing the non-bias computation (temp_0 = torch.bmm(b1, b2))
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], self.args[2])
# doing the sum over the batch dimension (temp_1 = torch.sum(temp_0, 0))
sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
kwargs = self.extract_kwargs_from_origin_func()
if 'beta' in kwargs:
beta = kwargs['beta']
# doing the multiplication with beta if it exists(temp_2 = beta * input)
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
if 'alpha' in kwargs:
alpha = kwargs['alpha']
# doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1)
alpha_proxy = self.create_mul_node(alpha, sum_proxy)
else:
alpha_proxy = sum_proxy
# doing the addition(temp_4 = temp_2 + temp_3)
bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)
return bias_addition_proxy
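# A hedged check, with plain tensors and no tracing, of the decomposition that `generate`
# builds: addbmm(input, b1, b2, beta, alpha) == beta * input + alpha * bmm(b1, b2).sum(0).
# The shapes below are illustrative only:
# >>> inp = torch.rand(3, 5)
# >>> b1, b2 = torch.rand(10, 3, 4), torch.rand(10, 4, 5)
# >>> ref = torch.addbmm(inp, b1, b2, beta=2.0, alpha=0.5)
# >>> alt = 2.0 * inp + 0.5 * torch.bmm(b1, b2).sum(dim=0)
# >>> torch.allclose(ref, alt, atol=1e-5)
# True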
import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addmm)
@bias_addition_function.register(torch.addmm)
class Addmm(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self):
kwargs = {}
if 'beta' in self.kwargs:
kwargs['beta'] = self.kwargs['beta']
if 'alpha' in self.kwargs:
kwargs['alpha'] = self.kwargs['alpha']
return kwargs
def transpose_other_operand_for_linear(self, other_proxy):
'''
This method is used to transpose the other operand for linear function.
For example:
input = torch.rand(3, 4)
m1 = torch.rand(3, 5)
m2 = torch.rand(5, 4)
original_output = torch.addmm(input, m1, m2)
# To keep the computation graph consistent with the original computation graph, we need to transpose m2
# before we call the linear function.
new_output = F.linear(m1, m2.transpose(0, 1)) + input
'''
node_kind = 'call_function'
node_target = torch.transpose
node_args = (other_proxy, 0, 1)
node_kwargs = {}
transpose_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return transpose_proxy
def generate(self):
transpose_proxy = self.transpose_other_operand_for_linear(self.args[2])
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy)
kwargs = self.extract_kwargs_from_origin_func()
if 'beta' in kwargs:
beta = kwargs['beta']
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
if 'alpha' in kwargs:
alpha = kwargs['alpha']
alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
else:
alpha_proxy = non_bias_linear_func_proxy
bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)
return bias_addition_proxy
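# A hedged check, with plain tensors and no tracing, of the rewrite used by `generate`:
# addmm(input, m1, m2, beta, alpha) == beta * input + alpha * F.linear(m1, m2.transpose(0, 1)).
# The shapes below are illustrative only:
# >>> inp = torch.rand(3, 4)
# >>> m1, m2 = torch.rand(3, 5), torch.rand(5, 4)
# >>> ref = torch.addmm(inp, m1, m2, beta=2.0, alpha=0.5)
# >>> alt = 2.0 * inp + 0.5 * F.linear(m1, m2.transpose(0, 1))
# >>> torch.allclose(ref, alt, atol=1e-5)
# True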
import operator
from abc import ABC, abstractmethod
import torch
import torch.nn.functional as F
class BiasAdditionFunc(ABC):
"""
This class is used to construct the restructure computation graph for
call_func node with bias addition inside.
"""
def __init__(self, tracer, target, args, kwargs, substitute_func):
self.tracer = tracer
self.target = target
self.args = args
self.kwargs = kwargs
self.substitute_func = substitute_func
@abstractmethod
def extract_kwargs_from_origin_func(self):
"""
This method is used to extract the kwargs for further graph transform.
For example:
The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
The kwargs for the addmm function are {beta=1, alpha=1, out=None}; we then need
to insert two more operator.mul nodes into the computation graph to compute the
final result.
"""
pass
@abstractmethod
def generate(self):
"""
This method is used to construct the whole restructured computation graph for a call_func node with bias
addition inside.
The restructured computation graph will contain the non-bias computation node, coefficient multiplication
(operator.mul) nodes if needed, and a bias addition node.
Use torch.addmm as an example:
The origin node is:
%addmm: call_func[target=torch.addmm](args = (%input_1, m1, m2), kwargs = {beta=1, alpha=1})
Restructured graph is:
%transpose : [#users=1] = call_function[target=torch.transpose](args = (%m2, 0, 1), kwargs = {})
%linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%m1, %transpose), kwargs = {})
%mul : [#users=1] = call_function[target=operator.mul](args = (%input_1, 3), kwargs = {})
%mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
"""
pass
def create_mul_node(self, input_proxy, coefficient):
"""
This method is used to create a coefficient node for numerical correctness.
The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
Therefore, we need to use this method to insert two more operator.mul nodes into
the computation graph to compute the final result.
"""
node_kind = 'call_function'
node_target = operator.mul
node_args = (
input_proxy,
coefficient,
)
node_kwargs = {}
mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return mul_proxy
class LinearBasedBiasFunc(BiasAdditionFunc):
"""
This class is used to construct the restructure computation graph for
call_func node based on F.linear.
"""
def create_non_bias_func_proxy(self, input_proxy, other_proxy):
"""
This method is used to create the non_bias_func proxy; the node created by this proxy will
perform the main computation (F.linear here) without the bias term.
"""
assert self.substitute_func == torch.nn.functional.linear
node_kind = 'call_function'
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
# non-bias linear does not have any kwargs
node_kwargs = {}
non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return non_bias_func_proxy
def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy):
"""
This method is used to create the bias_addition_proxy; the node created by this proxy will
compute the sum of the non_bias_func result and the bias, with a reshape operation if needed.
"""
bias_add_node_kind = 'call_function'
bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
return bias_add_proxy
func_to_func_dict = {
torch.addmm: F.linear,
torch.addbmm: torch.bmm,
F.linear: F.linear,
}
method_to_func_dict = {
torch.Tensor.addmm: F.linear,
torch.Tensor.addbmm: torch.bmm,
}
import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function
from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_function.register(F.linear)
class Linear(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self):
assert 'bias' in self.kwargs
kwargs = {}
if 'bias' in self.kwargs:
kwargs['bias'] = self.kwargs['bias']
return kwargs
def generate(self):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1])
kwargs = self.extract_kwargs_from_origin_func()
bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias'])
return bias_addition_proxy
from .bias_addition_module import *
from .conv import *
from .linear import *
import operator
from abc import ABC, abstractmethod
import torch
import torch.nn.functional as F
class BiasAdditionModule(ABC):
"""
This class is used to construct the restructure computation graph for
call_module node with bias addition inside.
"""
def __init__(self, tracer, target, args, kwargs, substitute_func):
self.tracer = tracer
self.target = target
self.args = args
self.kwargs = kwargs
self.substitute_func = substitute_func
self.weight_proxy = self._create_weight_proxy()
self.bias_proxy = self._create_bias_proxy()
def _create_weight_proxy(self):
"""
Create the weight proxy; the node created by this proxy contains the module weight.
Note: this function will be invoked during module initialization;
you should never call it directly.
"""
weight_node_kind = 'get_attr'
weight_node_target = self.target + '.weight'
weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {})
return weight_proxy
def _create_bias_proxy(self):
"""
Create the bias proxy; the node created by this proxy contains the module bias.
Note: this function will be invoked during module initialization;
you should never call it directly.
"""
bias_node_kind = 'get_attr'
bias_node_target = self.target + '.bias'
bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {})
return bias_proxy
@abstractmethod
def extract_kwargs_from_mod(self):
"""
This method is used to extract the kwargs for non-bias computation.
For example:
The kwargs for a conv2d module are {} because attributes like 'padding' or 'groups' are
handled during module initialization. However, we need to pass those attributes as kwargs
to F.conv2d.
"""
pass
def create_non_bias_func_proxy(self, input_proxy=None):
"""
This method is used to create the non_bias_func proxy; the node created by this proxy will
perform the main computation, such as convolution, with the bias option disabled.
"""
node_kind = 'call_function'
node_target = self.substitute_func
if input_proxy is None:
input_proxy = self.args[0]
node_args = (input_proxy, self.weight_proxy)
node_kwargs = self.extract_kwargs_from_mod()
non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return non_bias_func_proxy
def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy):
"""
This method is used to create the bias_addition_proxy; the node created by this proxy will
compute the sum of the non_bias_func result and the bias, with a reshape operation if needed.
"""
bias_add_node_kind = 'call_function'
bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
return bias_add_proxy
@abstractmethod
def generate(self):
"""
This method is used to construct the whole restructured computation graph for a call_module node with bias
addition inside.
A restructured computation graph will contain a weight node, a bias node, a non-bias computation node,
a bias reshape node if needed and a bias addition node.
Use Conv2d module as an example:
The origin node is:
%conv: call_module[target=conv](args = (%x,), kwargs = {})
Restructured graph is:
%conv_weight : [#users=1] = get_attr[target=conv.weight]
%conv_bias : [#users=1] = get_attr[target=conv.bias]
%conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
%view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
"""
pass
module_to_func_dict = {
torch.nn.Linear: F.linear,
torch.nn.Conv1d: F.conv1d,
torch.nn.Conv2d: F.conv2d,
torch.nn.Conv3d: F.conv3d,
}
import torch
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
@bias_addition_module.register(torch.nn.Conv1d)
@bias_addition_module.register(torch.nn.Conv2d)
@bias_addition_module.register(torch.nn.Conv3d)
class BiasAdditionConv(BiasAdditionModule):
def extract_kwargs_from_mod(self):
root = self.tracer.root
conv_module = root.get_submodule(self.target)
kwarg_attributes = ['groups', 'dilation', 'stride']
non_bias_kwargs = {}
for attr_name in kwarg_attributes:
if hasattr(conv_module, attr_name):
non_bias_kwargs[attr_name] = getattr(conv_module, attr_name)
if conv_module.padding_mode != "zeros":
# TODO: non-zeros padding mode requires some extra processing for the input
conv_type = type(conv_module)
if conv_type == torch.nn.Conv1d:
padding_element = _single(0)
elif conv_type == torch.nn.Conv2d:
padding_element = _pair(0)
elif conv_type == torch.nn.Conv3d:
padding_element = _triple(0)
non_bias_kwargs['padding'] = padding_element
else:
non_bias_kwargs['padding'] = getattr(conv_module, 'padding')
return non_bias_kwargs
def create_bias_reshape_proxy(self, dimensions):
"""
This method is used to reshape the bias node in order to make bias and
output of non-bias convolution broadcastable.
"""
bias_shape = [1] * (dimensions - 1)
bias_shape[0] = -1
bias_reshape_node_kind = 'call_method'
bias_reshape_node_target = 'view'
bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
bias_reshape_node_args, {})
return bias_reshape_proxy
def generate(self):
non_bias_conv_func_proxy = self.create_non_bias_func_proxy()
output_dims = non_bias_conv_func_proxy.meta_data.dim()
bias_reshape_proxy = self.create_bias_reshape_proxy(output_dims)
bias_addition_proxy = self.create_bias_addition_proxy(non_bias_conv_func_proxy, bias_reshape_proxy)
return bias_addition_proxy
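# A hedged check, with plain tensors and no tracing, of the restructuring this class performs
# for Conv2d: module(x) == F.conv2d(x, weight, **non_bias_kwargs) + bias.view(-1, 1, 1).
# The module configuration below is illustrative only:
# >>> conv = torch.nn.Conv2d(3, 8, 3, stride=2, padding=1)
# >>> x = torch.rand(2, 3, 16, 16)
# >>> ref = conv(x)
# >>> alt = F.conv2d(x, conv.weight, stride=conv.stride, padding=conv.padding,
# ...                dilation=conv.dilation, groups=conv.groups) + conv.bias.view(-1, 1, 1)
# >>> torch.allclose(ref, alt, atol=1e-5)
# True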
import torch
import torch.nn.functional as F
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
@bias_addition_module.register(torch.nn.Linear)
class BiasAdditionLinear(BiasAdditionModule):
def extract_kwargs_from_mod(self):
return {}
def generate(self):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy()
bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, self.bias_proxy)
return bias_addition_proxy
import enum
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
from torch.fx import Graph, Node, Proxy, Tracer
from torch.utils._pytree import tree_map
from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
Target = Union[Callable[..., Any], str]
Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
List[Any], # actually Argument
Dict[str, Any], # actually Argument
slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
'Node',]]
_CScriptMethod = ['add', 'mul', 'sub', 'div']
_TorchNewMethod = [
"arange", "zeros", "zeros_like", "ones", "ones_like", "full", "full_like", "empty", "empty_like", "eye", "tensor",
"finfo"
]
_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"]
def _truncate_suffix(s: str):
import re
return re.sub(r'_\d+$', '', s)
def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
if isinstance(elements, (tuple, list, set)):
for ele in elements:
if ele not in list_:
return False, ele
else:
if elements not in list_:
return False, elements
return True, None
def default_device():
return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
@compatibility(is_backward_compatible=False)
class ColoProxy(Proxy):
def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, **kwargs)
self._data = data
@property
def data(self):
return self._data
@data.setter
def data(self, args):
wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
self._data = tree_map(wrap_fn, args)
@classmethod
def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p
kwargs = {} if kwargs is None else kwargs
if proxy.data is None:
proxy.data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
return proxy
@classmethod
def from_torch_proxy(cls, proxy: Proxy):
return cls(proxy.node, proxy.tracer)
def __repr__(self):
return f"ColoProxy({self.node.name}, data={self.data})"
def __len__(self):
return len(self.data)
def __int__(self):
return int(self.data)
def __index__(self):
try:
return int(self.data)
except:
return torch.zeros(self.data.shape, dtype=torch.bool).numpy().__index__()
def __float__(self):
return float(self.data)
def __bool__(self):
return self.data
def __getattr__(self, k):
return ColoAttribute(self, k, getattr(self._data, k, None))
def __contains__(self, key):
if self.node.op == "placeholder":
# this is used to handle like
# if x in kwargs
# we don't handle this case for now
return False
return super().__contains__(key)
def __isinstancecheck__(self, type):
return isinstance(self.data, type)
@property
def shape(self):
return self.data.shape
@property
def ndim(self):
return self.data.ndim
@property
def device(self):
proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
proxy.data = self.data.device
return proxy
@property
def dtype(self):
proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
proxy.data = self.data.dtype
return proxy
def to(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})
def cpu(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})
def cuda(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})
@compatibility(is_backward_compatible=False)
class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
self.tracer = root.tracer
self._data = data
self._node: Optional[Node] = None
@property
def node(self):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"
@compatibility(is_backward_compatible=False)
class ColoTracer(Tracer):
def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
super().__init__(*args, **kwargs)
self._disable_module_getattr = False
self.proxy_buffer_attributes = True
def proxy(self, node: Node) -> 'ColoProxy':
return ColoProxy(node, self)
def create_proxy(self,
kind: str,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p
if kind == 'placeholder':
proxy.data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
_truncate_suffix(target), None)
elif kind == 'get_attr':
self._disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
proxy.data = attr_itr
finally:
self._disable_module_getattr = False
elif kind == 'call_function':
proxy.data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
elif kind == 'call_method':
self._disable_module_getattr = True
try:
if target == '__call__':
proxy.data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
proxy._data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
**tree_map(unwrap_fn, kwargs))
finally:
self._disable_module_getattr = False
elif kind == 'call_module':
mod = self.root.get_submodule(target)
unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p
self._disable_module_getattr = True
try:
proxy.data = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
finally:
self._disable_module_getattr = False
return proxy
def trace(self,
root: torch.nn.Module,
concrete_args: Optional[Dict[str, torch.Tensor]] = None,
meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
if meta_args is None:
meta_args = {}
if concrete_args is None:
concrete_args = {}
# check concrete and meta args have valid names
sig = inspect.signature(root.forward)
sig_names = set(sig.parameters.keys())
meta_arg_names = set(meta_args.keys())
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
if k in non_meta_arg_names and \
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
# get non concrete arg names
concrete_arg_names = set(concrete_args.keys())
non_concrete_arg_names = sig_names - concrete_arg_names
def _check_arg_name_valid(names):
success, element = is_element_in_list(names, sig_names)
if not success:
raise KeyError(
f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
_check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names)
self.concrete_args = concrete_args
self.meta_args = meta_args
with _TorchTensorOverride(self):
self.graph = super().trace(root, concrete_args=concrete_args)
self.graph.lint()
return self.graph
def _post_check(self, non_concrete_arg_names: Set[str]):
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
for node in self.graph.nodes:
if node.op == "placeholder":
# Removing default values for inputs as the forward pass will fail with them.
if node.target in non_concrete_arg_names:
node.args = ()
# Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
# It cannot infer on the attributes and methods the input should have, and fails.
node.type = torch.Tensor
# It is a concrete arg so it is not used and should be removed.
else:
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
# Newer versions of torch.fx emit an assert statement
# for concrete arguments; delete those before we delete
# the concrete arg.
to_delete = []
for user in node.users:
if user.target == torch.fx._symbolic_trace._assert_is_none:
to_delete.append(user)
for user in to_delete:
self.graph.erase_node(user)
self.graph.erase_node(node)
# TODO: solves GraphModule creation.
# Without this, return type annotation "Tuple" is causing code execution failure.
if node.op == "output":
node.type = None
self.graph.lint()
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, "_disable_module_getattr", False):
return attr_val
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
for n, p in collection_to_search:
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
lambda node: ColoProxy(self, node, n, attr_val))
val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache)
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
parameter_proxy_cache)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
return attr_val
@compatibility(is_backward_compatible=True)
def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None,
) -> ColoGraphModule:
if is_compatible_with_meta():
if meta_args is not None:
root.to(default_device())
wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
root.cpu()
else:
graph = Tracer().trace(root, concrete_args=concrete_args)
else:
from .tracer import ColoTracer as OrigColoTracer
graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
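# A minimal usage sketch of this experimental `symbolic_trace` (assumptions for
# illustration: torchvision is installed; the model and input shape are arbitrary):
# >>> import torchvision.models as tm
# >>> model = tm.resnet18()
# >>> gm = symbolic_trace(model, meta_args={'x': torch.rand(2, 3, 224, 224)})
# >>> gm.graph.print_tabular()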
@compatibility(is_backward_compatible=False)
class _TorchTensorOverride(object):
def __init__(self, tracer: Tracer):
self.overrides = {}
self.tracer = tracer
def __enter__(self):
def wrap_tensor_method(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
isinstance(p, ColoProxy) for p in kwargs.values())
if is_proxy:
# if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy
self.tracer._disable_module_getattr = True
try:
proxy = self.tracer.create_proxy('call_function', target, args, kwargs)
finally:
self.tracer._disable_module_getattr = False
return proxy
else:
return target(*args, **kwargs)
return wrapper, target
self.overrides = {
target: wrap_tensor_method(getattr(torch, target))
for target in _TorchNewMethod
if callable(getattr(torch, target))
}
for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, wrapper)
def __exit__(self, exc_type, exc_val, exc_tb):
for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, orig)
from .patched_function import *
from .patched_module import *