Commit 08f2920e authored by zhuwenwen's avatar zhuwenwen
Browse files

init colossalai, support dtk2304

parent da3f0934
Pipeline #237 failed with stages
in 0 seconds
from .activation_function import *
from .arithmetic import *
from .convolution import *
from .embedding import *
from .normalization import *
from .torch_ops import *
import torch
from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.relu)
def torch_nn_func_relu(input, inplace=False):
    """Meta patch for ``torch.nn.functional.relu``.

    Returns an uninitialized meta tensor shaped like ``input``. ``inplace``
    is accepted for signature compatibility and is not used.
    """
    output_shape = input.shape
    return torch.empty(output_shape, device='meta')
import torch
from ...registry import meta_patched_function
@meta_patched_function.register(torch.matmul)
@meta_patched_function.register('matmul')    # for built-in op @
def torch_matmul(input, other, *, out=None):
    """Meta patch for ``torch.matmul`` / the ``@`` operator.

    Derives the result shape from the ranks of ``input`` and ``other`` and
    returns an uninitialized meta tensor of that shape (a 0-dim meta tensor
    for the vector-vector case). ``out`` is accepted but unused.
    """
    # adapted from huggingface.utils.fx
    lhs_rank = input.dim()
    rhs_rank = other.dim()

    # Low-rank combinations have closed-form result shapes.
    if lhs_rank == 1 and rhs_rank == 1:
        return torch.tensor(0.0, device="meta")
    if lhs_rank == 2 and rhs_rank == 2:
        return torch.empty(input.size(0), other.size(1), device="meta")
    if lhs_rank == 1 and rhs_rank == 2:
        return torch.empty(other.size(1), device="meta")
    if lhs_rank == 2 and rhs_rank == 1:
        return torch.empty(input.size(0), device="meta")

    # Batched case: left-pad the shorter shape with -1 so that the
    # element-wise max always picks the real extent of the other operand.
    rank = max(lhs_rank, rhs_rank)
    lhs_dims = [-1] * (rank - lhs_rank) + list(input.shape)
    rhs_dims = [-1] * (rank - rhs_rank) + list(other.shape)
    result = [max(lhs_dim, rhs_dim) for lhs_dim, rhs_dim in zip(lhs_dims, rhs_dims)]

    # The two trailing axes come from the matrix product itself.
    result[-2] = lhs_dims[-2]
    result[-1] = rhs_dims[-1]

    # A 1-D operand contributes no row (lhs) / column (rhs) axis.
    if lhs_rank == 1:
        result.pop(-2)
    if rhs_rank == 1:
        result.pop(-1)
    return torch.empty(*result, device="meta")
@meta_patched_function.register(torch.abs)
def torch_abs(input, *, out=None):
    """Meta patch for ``torch.abs``: same shape as ``input``, on the meta device.

    Writing into ``out`` is not supported and triggers an assertion.
    """
    assert out is None, 'out is not supported yet'
    result_shape = input.shape
    return torch.empty(result_shape, device='meta')
@meta_patched_function.register(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
    """Meta patch for ``torch.bmm``.

    Both operands are unpacked as 3-D shapes; the result is an uninitialized
    meta tensor of shape ``(batch, n, p)`` where ``input`` is ``(batch, n, m)``
    and ``mat2`` is ``(_, _, p)``.

    Raises:
        ValueError: if ``out`` is given (writing into ``out`` is unsupported).
    """
    if out is not None:
        # The original message was copy-pasted and wrongly mentioned "abs".
        raise ValueError("Don't support the `out` argument of torch.bmm for MetaTensor analysis")
    batch_size, n, _ = input.shape
    _, _, p = mat2.shape
    return torch.empty(batch_size, n, p, device="meta")
@meta_patched_function.register(torch.nn.functional.linear)
def torch_linear(input, mat2, bias=None, *, out=None):
    """Meta patch for ``torch.nn.functional.linear``.

    ``mat2`` plays the role of the weight: its first dimension replaces the
    last dimension of ``input`` in the result shape. ``bias`` is unused.

    Raises:
        ValueError: if ``out`` is given (writing into ``out`` is unsupported).
    """
    if out is not None:
        # The original message was copy-pasted and wrongly mentioned "abs".
        raise ValueError("Don't support the `out` argument of torch.nn.functional.linear for MetaTensor analysis")
    output_shape = list(input.shape)
    output_shape[-1] = mat2.shape[0]
    return torch.empty(*output_shape, device="meta")
@meta_patched_function.register(torch.addbmm)
@meta_patched_function.register(torch.Tensor.addbmm)
def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
    """Meta patch for ``torch.addbmm`` / ``Tensor.addbmm``.

    Returns an uninitialized ``(n, p)`` meta tensor, where ``mat1`` is
    unpacked as ``(_, n, _)`` and ``mat2`` as ``(_, _, p)``. ``input``,
    ``beta`` and ``alpha`` do not influence the shape and are unused.

    Raises:
        ValueError: if ``out`` is given (writing into ``out`` is unsupported).
    """
    if out is not None:
        # The original message was copy-pasted and wrongly mentioned "abs".
        raise ValueError("Don't support the `out` argument of torch.addbmm for MetaTensor analysis")
    _, n, _ = mat1.shape
    _, _, p = mat2.shape
    return torch.empty(n, p, device="meta")
@meta_patched_function.register(torch.addmm)
@meta_patched_function.register(torch.Tensor.addmm)
def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
    """Meta patch for ``torch.addmm`` / ``Tensor.addmm``.

    Returns an uninitialized ``(n, p)`` meta tensor, where ``mat1`` is
    unpacked as ``(n, _)`` and ``mat2`` as ``(_, p)``. ``input``, ``beta``
    and ``alpha`` do not influence the shape and are unused.

    Raises:
        ValueError: if ``out`` is given (writing into ``out`` is unsupported).
    """
    if out is not None:
        # The original message was copy-pasted and wrongly mentioned "abs".
        raise ValueError("Don't support the `out` argument of torch.addmm for MetaTensor analysis")
    n, _ = mat1.shape
    _, p = mat2.shape
    return torch.empty(n, p, device="meta")
@meta_patched_function.register(torch.var_mean)
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
    """Meta patch for ``torch.var_mean``.

    Returns a ``(var, mean)`` pair of uninitialized meta tensors whose shape
    is ``input.shape`` reduced over ``dim``. Reduced axes are dropped, or kept
    with extent 1 when ``keepdim`` is true. ``dim`` may be an int, a sequence
    of ints, or ``None`` / empty to reduce over every axis. ``unbiased`` does
    not affect the shape and is unused.

    The previous implementation always returned 0-dim tensors, which is only
    correct for a full reduction without ``keepdim``.
    """
    assert out is None, 'saving to out is not supported yet'

    ndim = input.dim()
    if dim is None:
        dims = ()
    elif isinstance(dim, int):
        dims = (dim,)
    else:
        dims = tuple(dim)

    if dims:
        # Normalize negative axes; max(ndim, 1) keeps dim=0/-1 valid for 0-dim input.
        reduced = {axis % max(ndim, 1) for axis in dims}
    else:
        reduced = set(range(ndim))

    result_shape = []
    for axis, extent in enumerate(input.shape):
        if axis not in reduced:
            result_shape.append(extent)
        elif keepdim:
            result_shape.append(1)

    var = torch.empty(result_shape, device='meta')
    mean = torch.empty(result_shape, device='meta')
    return var, mean
import collections
import math
from itertools import repeat
import torch
from ...registry import meta_patched_function
def _ntuple(n, name="parse"):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return tuple(x)
return tuple(repeat(x, n))
parse.__name__ = name
return parse
# Converters that expand a scalar argument to a 1-, 2- or 3-tuple.
# Iterables are passed through as tuples without a length check.
_single = _ntuple(1, "_single")
_pair = _ntuple(2, "_pair")
_triple = _ntuple(3, "_triple")
def _extract_kwargs(kwargs):
if 'stride' in kwargs:
stride = kwargs['stride']
else:
stride = 1
# TODO: process str type padding
if 'padding' in kwargs:
padding = kwargs['padding']
else:
padding = 0
if 'dilation' in kwargs:
dilation = kwargs['dilation']
else:
dilation = 1
if 'output_padding' in kwargs:
output_padding = kwargs['output_padding']
else:
output_padding = 0
return stride, padding, dilation, output_padding
@meta_patched_function.register(torch.nn.functional.conv1d)
def torch_nn_functional_conv1d(input, weight, **kwargs):
    """Meta patch for ``torch.nn.functional.conv1d``.

    ``stride``, ``padding`` and ``dilation`` are read from ``kwargs``. The
    result keeps every leading axis of ``input`` except the last two, which
    become ``(weight.shape[0], l_out)`` with
    ``l_out = floor((l_in + 2*padding - dilation*(kernel-1) - 1) / stride + 1)``.
    """
    stride, padding, dilation = (_single(value) for value in _extract_kwargs(kwargs)[:3])
    kernel_size = weight.shape[2:]
    l_in = input.shape[-1]
    # receptive field of the dilated kernel
    kernel_span = dilation[0] * (kernel_size[0] - 1) + 1
    l_out = math.floor((l_in + 2 * padding[0] - kernel_span) / stride[0] + 1)
    return torch.empty(input.shape[:-2] + (weight.shape[0], l_out), device='meta')
@meta_patched_function.register(torch.nn.functional.conv2d)
def torch_nn_functional_conv2d(input, weight, **kwargs):
    """Meta patch for ``torch.nn.functional.conv2d``.

    ``stride``, ``padding`` and ``dilation`` are read from ``kwargs``. The
    last three axes of ``input`` are replaced by
    ``(weight.shape[0], h_out, w_out)``, each spatial size computed as
    ``floor((in + 2*padding - dilation*(kernel-1) - 1) / stride + 1)``.
    """
    stride, padding, dilation = (_pair(value) for value in _extract_kwargs(kwargs)[:3])
    kernel_size = weight.shape[2:]
    h_in, w_in = input.shape[-2:]
    spatial_out = tuple(
        math.floor((size_in + 2 * padding[axis] - dilation[axis] * (kernel_size[axis] - 1) - 1) / stride[axis] + 1)
        for axis, size_in in enumerate((h_in, w_in)))
    return torch.empty(input.shape[:-3] + (weight.shape[0],) + spatial_out, device='meta')
@meta_patched_function.register(torch.nn.functional.conv3d)
def torch_nn_functional_conv3d(input, weight, **kwargs):
    """Meta patch for ``torch.nn.functional.conv3d``.

    ``stride``, ``padding`` and ``dilation`` are read from ``kwargs``. The
    last four axes of ``input`` are replaced by
    ``(weight.shape[0], d_out, h_out, w_out)``, each spatial size computed as
    ``floor((in + 2*padding - dilation*(kernel-1) - 1) / stride + 1)``.
    """
    stride, padding, dilation = (_triple(value) for value in _extract_kwargs(kwargs)[:3])
    kernel_size = weight.shape[2:]
    d_in, h_in, w_in = input.shape[-3:]
    spatial_out = tuple(
        math.floor((size_in + 2 * padding[axis] - dilation[axis] * (kernel_size[axis] - 1) - 1) / stride[axis] + 1)
        for axis, size_in in enumerate((d_in, h_in, w_in)))
    return torch.empty(input.shape[:-4] + (weight.shape[0],) + spatial_out, device='meta')
@meta_patched_function.register(torch.nn.functional.conv_transpose1d)
def torch_nn_functional_convtranspose1d(input, weight, **kwargs):
    """Meta patch for ``torch.nn.functional.conv_transpose1d``.

    ``stride``, ``padding``, ``dilation`` and ``output_padding`` are read from
    ``kwargs``. The last two axes of ``input`` become
    ``(weight.shape[1], l_out)`` with
    ``l_out = (l_in - 1)*stride - 2*padding + dilation*(kernel-1) + output_padding + 1``.
    """
    stride, padding, dilation, output_padding = (_single(value) for value in _extract_kwargs(kwargs))
    kernel_size = weight.shape[2:]
    l_in = input.shape[-1]
    kernel_span = dilation[0] * (kernel_size[0] - 1)
    l_out = math.floor((l_in - 1) * stride[0] - 2 * padding[0] + kernel_span + output_padding[0] + 1)
    return torch.empty(input.shape[:-2] + (weight.shape[1], l_out), device='meta')
@meta_patched_function.register(torch.nn.functional.conv_transpose2d)
def torch_nn_functional_convtranspose2d(input, weight, **kwargs):
    """Meta patch for ``torch.nn.functional.conv_transpose2d``.

    ``stride``, ``padding``, ``dilation`` and ``output_padding`` are read from
    ``kwargs``. The last three axes of ``input`` become
    ``(weight.shape[1], h_out, w_out)``, each spatial size computed as
    ``(in - 1)*stride - 2*padding + dilation*(kernel-1) + output_padding + 1``.
    """
    stride, padding, dilation, output_padding = (_pair(value) for value in _extract_kwargs(kwargs))
    kernel_size = weight.shape[2:]
    h_in, w_in = input.shape[-2:]
    spatial_out = tuple(
        math.floor((size_in - 1) * stride[axis] - 2 * padding[axis] + dilation[axis] * (kernel_size[axis] - 1) +
                   output_padding[axis] + 1) for axis, size_in in enumerate((h_in, w_in)))
    return torch.empty(input.shape[:-3] + (weight.shape[1],) + spatial_out, device='meta')
@meta_patched_function.register(torch.nn.functional.conv_transpose3d)
def torch_nn_functional_convtranspose3d(input, weight, **kwargs):
    """Meta patch for ``torch.nn.functional.conv_transpose3d``.

    ``stride``, ``padding``, ``dilation`` and ``output_padding`` are read from
    ``kwargs``. The last four axes of ``input`` become
    ``(weight.shape[1], d_out, h_out, w_out)``, each spatial size computed as
    ``(in - 1)*stride - 2*padding + dilation*(kernel-1) + output_padding + 1``.
    """
    stride, padding, dilation, output_padding = (_triple(value) for value in _extract_kwargs(kwargs))
    kernel_size = weight.shape[2:]
    d_in, h_in, w_in = input.shape[-3:]
    spatial_out = tuple(
        math.floor((size_in - 1) * stride[axis] - 2 * padding[axis] + dilation[axis] * (kernel_size[axis] - 1) +
                   output_padding[axis] + 1) for axis, size_in in enumerate((d_in, h_in, w_in)))
    return torch.empty(input.shape[:-4] + (weight.shape[1],) + spatial_out, device='meta')
import torch
from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.embedding)
def torch_nn_functional_embedding(input,
                                  weight,
                                  padding_idx=None,
                                  max_norm=None,
                                  norm_type=2.0,
                                  scale_grad_by_freq=False,
                                  sparse=False):
    """Meta patch for ``torch.nn.functional.embedding``.

    The result shape is ``input.shape`` with the last dimension of ``weight``
    appended. All remaining arguments are accepted but unused.
    """
    result_shape = tuple(input.shape) + (weight.shape[-1],)
    return torch.empty(result_shape, device="meta")
import torch
from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.layer_norm)
def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
    """Meta patch for ``torch.nn.functional.layer_norm``.

    Returns an uninitialized meta tensor shaped like ``input``; the other
    arguments are accepted but unused.
    """
    output_shape = input.shape
    return torch.empty(output_shape, device='meta')
@meta_patched_function.register(torch.nn.functional.batch_norm)
def torch_nn_func_batchnorm(input,
                            running_mean,
                            running_var,
                            weight=None,
                            bias=None,
                            training=False,
                            momentum=0.1,
                            eps=1e-05):
    """Meta patch for ``torch.nn.functional.batch_norm``.

    Returns an uninitialized meta tensor shaped like ``input``; the running
    statistics and other arguments are accepted but unused.
    """
    output_shape = input.shape
    return torch.empty(output_shape, device='meta')
import operator
import torch
from colossalai.fx.proxy import ColoProxy
from ...registry import meta_patched_function
@meta_patched_function.register(operator.getitem)
def operator_getitem(a, b):
    """Meta patch for ``operator.getitem`` (``a[b]``).

    ``ColoProxy`` values found in slice bounds are replaced by their
    ``meta_data``. When ``a`` is a tensor (or a ``ColoProxy`` carrying one),
    the indexing is executed on an empty CPU tensor of the same shape and the
    result is moved to the meta device; tensor indices are replaced by CPU
    tensors of ones. Any other ``a`` is indexed directly.
    """

    # adapted from huggingface.utils.fx
    def _materialize_index(index):
        # Non-tensor indices pass through untouched.
        if not isinstance(index, torch.Tensor):
            return index
        concrete = torch.ones_like(index, device="cpu")
        if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
            concrete = concrete.to(torch.int64)
        return concrete

    def _unwrap(bound):
        return bound.meta_data if isinstance(bound, ColoProxy) else bound

    def _resolve_slice(slice_obj):
        return slice(_unwrap(slice_obj.start), _unwrap(slice_obj.stop), _unwrap(slice_obj.step))

    # Step 1: strip proxies from slice bounds.
    if isinstance(b, tuple):
        b = tuple(_resolve_slice(item) if isinstance(item, slice) else item for item in b)
    elif isinstance(b, slice):
        b = _resolve_slice(b)

    # Step 2: find the tensor whose shape drives the result, if any.
    if isinstance(a, torch.Tensor):
        shape_source = a
    elif isinstance(a, ColoProxy):
        shape_source = a.meta_data
    else:
        return operator.getitem(a, b)

    # TODO: infer shape without performing the computation.
    if isinstance(b, tuple):
        b = tuple(map(_materialize_index, b))
    else:
        b = _materialize_index(b)
    return operator.getitem(torch.empty_like(shape_source, device="cpu"), b).to("meta")
import torch
from ...registry import meta_patched_function
@meta_patched_function.register(torch.arange)
def torch_arange(*args, **kwargs):
    """Meta patch for ``torch.arange``.

    Accepts ``(end)``, ``(start, end)`` or ``(start, end, step)`` positionally;
    ``start``, ``end``, ``step`` and ``dtype`` may also be given as keywords.
    Returns an uninitialized 1-D meta tensor with
    ``ceil((end - start) / step)`` elements.

    Fixes over the previous version:
      * a float ``end`` was written into ``start`` instead of ``end``;
      * floor division under-counted non-divisible ranges
        (``arange(0, 5, 2)`` has 3 elements, not 2);
      * float arguments were truncated to int before the length was computed;
      * with no explicit ``dtype`` the result is now int64 for all-integer
        arguments and the default float dtype otherwise, matching
        ``torch.arange``.

    Raises:
        TypeError: if no ``end`` value is supplied.
    """
    import math    # local import: this module only imports torch at file level

    n = len(args)
    start, end, step = 0, None, 1
    if n == 1:
        end = args[0]
    elif n == 2:
        start, end = args
    elif n >= 3:
        start, end, step = args

    start = kwargs.get("start", start)
    end = kwargs.get("end", end)
    step = kwargs.get("step", step)
    if end is None:
        raise TypeError("torch.arange() requires an `end` value")

    dtype = kwargs.get("dtype")
    if dtype is None:
        has_float = any(isinstance(value, float) for value in (start, end, step))
        dtype = torch.get_default_dtype() if has_float else torch.int64

    length = math.ceil((end - start) / step)
    return torch.empty(length, dtype=dtype, device="meta")
@meta_patched_function.register(torch.finfo)
def torch_finfo(*args):
    """Meta patch for ``torch.finfo``: forwards straight to the real function."""
    info = torch.finfo(*args)
    return info
@meta_patched_function.register(torch.where)
def torch_where(condition, x, y):
    """Meta patch for ``torch.where``.

    ``torch.where`` broadcasts ``condition``, ``x`` and ``y`` together, so the
    result shape is obtained by adding the three operands on the meta device.
    """
    meta_condition, meta_x, meta_y = (operand.to(device="meta") for operand in (condition, x, y))
    return meta_condition + meta_x + meta_y
@meta_patched_function.register(torch.Tensor.repeat)
def torch_tensor_repeat(self, *sizes):
    """Meta patch for ``Tensor.repeat``.

    ``sizes`` may be given as separate ints or as a single tuple/list.
    When more sizes than tensor dimensions are given, the tensor shape is
    left-padded with ones before multiplying, as ``Tensor.repeat`` does. The
    previous version raised ``IndexError`` in that case and mishandled a
    single sequence argument.

    Raises:
        RuntimeError: if fewer sizes than tensor dimensions are given.
    """
    if len(sizes) == 1 and isinstance(sizes[0], (tuple, list)):
        sizes = tuple(sizes[0])

    dims = list(self.shape)
    extra_dims = len(sizes) - len(dims)
    if extra_dims < 0:
        raise RuntimeError(
            "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor")

    padded_dims = [1] * extra_dims + dims
    shape = [extent * count for extent, count in zip(padded_dims, sizes)]
    return torch.empty(shape, device="meta")
@meta_patched_function.register(torch.index_select)
def torch_index_select(input, dim, index, *, out=None):
    """Meta patch for ``torch.index_select``.

    The result has the shape of ``input`` with axis ``dim`` resized to
    ``len(index)``. ``out`` is accepted but unused.
    """
    result_dims = list(input.shape)
    result_dims[dim] = len(index)
    return torch.empty(*result_dims, device="meta")
@meta_patched_function.register(torch.Tensor.index_select)
def torch_tensor_index_select(self, dim, index):
    """Meta patch for ``Tensor.index_select``; delegates to ``torch_index_select``."""
    return torch_index_select(self, dim=dim, index=index)
@meta_patched_function.register(torch.squeeze)
def torch_squeeze(input, dim=None):
    """Meta patch for ``torch.squeeze``.

    With ``dim=None`` every size-1 axis is removed. Otherwise only axis
    ``dim`` (negative values count from the end) is removed, and only if its
    size is 1.
    """
    dims = list(input.shape)
    if dim is None:
        dims = [extent for extent in dims if extent != 1]
    else:
        axis = input.dim() + dim if dim < 0 else dim
        if dims[axis] == 1:
            del dims[axis]
    return torch.empty(dims, device="meta")
@meta_patched_function.register(torch.Tensor.squeeze)
def torch_tensor_squeeze(self, dim=None):
    """Meta patch for ``Tensor.squeeze``; delegates to ``torch_squeeze``."""
    return torch_squeeze(self, dim=dim)
@meta_patched_function.register(torch.unsqueeze)
def torch_unsqueeze(input, dim):
    """Meta patch for ``torch.unsqueeze``.

    Inserts a size-1 axis at position ``dim``; a negative ``dim`` is offset
    by ``input.dim() + 1`` so that ``-1`` appends a trailing axis.
    """
    axis = dim if dim >= 0 else input.dim() + 1 + dim
    dims = list(input.shape)
    dims.insert(axis, 1)
    return torch.empty(dims, device="meta")
@meta_patched_function.register(torch.Tensor.unsqueeze)
def torch_tensor_unsqueeze(self, dim):
    """Meta patch for ``Tensor.unsqueeze``; delegates to ``torch_unsqueeze``."""
    return torch_unsqueeze(self, dim=dim)
@meta_patched_function.register(torch.cat)
def torch_cat(tensors, dim=None, axis=None, *, out=None):
    """Meta patch for ``torch.cat``.

    The concatenation axis is ``dim`` if given, else ``axis``, else 0
    (negative values count from the end of the first tensor's shape). The
    result takes the first tensor's shape with that axis set to the sum of
    all tensors' sizes along it. ``out`` is accepted but unused.
    """
    if dim is None:
        dim = 0 if axis is None else axis
    if dim < 0:
        dim = tensors[0].dim() + dim

    all_shapes = [tensor.shape for tensor in tensors]
    base_shape = list(all_shapes[0])
    total_along_dim = sum(tensor_shape[dim] for tensor_shape in all_shapes)
    result_shape = base_shape[:dim] + [total_along_dim] + base_shape[dim + 1:]
    return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.repeat_interleave)
def torch_repeat_interleave(input, repeats, dim=None, output_size=None):
    """Meta patch for ``torch.repeat_interleave``.

    With ``dim=None`` the input is treated as flattened. The size along the
    repeated axis is chosen as follows:

      * ``output_size`` when it is provided (previously ignored);
      * ``size * repeats`` for an int, or for a one-element tensor that
        broadcasts over the axis (previously not broadcast);
      * ``repeats.sum()`` for any other tensor, converted to a plain int so
        the shape list holds only ints.

    A tensor ``repeats`` must hold real values (not be on the meta device)
    unless ``output_size`` is supplied.
    """
    assert isinstance(repeats, (int, torch.Tensor)), \
        "Argument 'repeats' should be of type 'torch.Tensor' or 'int'"

    shape = list(input.shape) if dim is not None else [input.numel()]
    dim = dim if dim is not None else 0
    dim = input.dim() + dim if dim < 0 else dim

    if output_size is not None:
        shape[dim] = int(output_size)
    elif isinstance(repeats, int):
        shape[dim] = shape[dim] * repeats
    elif repeats.numel() == 1:
        shape[dim] = shape[dim] * int(repeats.item())
    else:
        shape[dim] = int(repeats.sum())
    return torch.empty(shape, device="meta")
@meta_patched_function.register(torch.Tensor.repeat_interleave)
def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None):
    """Meta patch for ``Tensor.repeat_interleave``; delegates to ``torch_repeat_interleave``."""
    return torch_repeat_interleave(self, repeats, dim=dim, output_size=output_size)
@meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None):
    """Meta patch for ``torch.roll``: the result is shaped like ``input``.

    ``shifts`` and ``dims`` are accepted but unused.
    """
    output_shape = input.shape
    return torch.empty(output_shape, device='meta')
@meta_patched_function.register(torch.full)
def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
    """Meta patch for ``torch.full``.

    Returns an uninitialized meta tensor of shape ``size`` carrying the
    requested ``dtype``, ``layout`` and ``requires_grad``. ``fill_value`` and
    ``device`` are unused: the result always lives on the meta device.
    """
    assert out is None, 'assigning result to out is not supported yet'
    tensor_options = dict(device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad)
    return torch.empty(size, **tensor_options)
@meta_patched_function.register(torch.max)
def torch_max(input, dim=None, keepdim=False, *, out=None):
    """Meta patch for ``torch.max``.

    * ``dim is None``: full reduction, returns a 0-dim meta tensor.
    * ``dim`` is an int: returns ``(values, indices)`` with axis ``dim``
      removed, or kept with size 1 when ``keepdim`` is true. ``indices`` is
      int64, as returned by ``torch.max``.
    * ``dim`` is a 0-D/1-D tensor (the ``torch.max(input, other)`` form):
      returns a meta tensor shaped like ``input``.

    Fixes over the previous version: a negative ``dim`` combined with
    ``keepdim=True`` put the singleton axis in the wrong position
    (``pop(-1)`` followed by ``insert(-1, 1)``), and ``indices`` wrongly
    inherited the input dtype.

    Raises:
        ValueError: if ``dim`` is a tensor with more than one dimension.
    """
    assert out is None, 'assigning value to out is not supported yet'

    if dim is None:
        return torch.empty([], device='meta', dtype=input.dtype)

    if isinstance(dim, int):
        shape = list(input.shape)
        if shape:    # a 0-dim input has no axis to drop
            if keepdim:
                shape[dim] = 1
            else:
                shape.pop(dim)
        values = torch.empty(shape, device='meta', dtype=input.dtype)
        indices = torch.empty(shape, device='meta', dtype=torch.int64)
        return values, indices

    if isinstance(dim, torch.Tensor):
        # when dim is a 0D or 1D tensor, it will maintain the same shape
        num_dims = dim.dim()
        if num_dims in [0, 1]:
            return torch.empty_like(input, device='meta')
        raise ValueError(f"Expected dim to a 0D or 1D tensor but got {num_dims} dimensions")
@meta_patched_function.register(torch.Tensor.cpu)
def torch_tensor_cpu(input):
    """Meta patch for ``Tensor.cpu``: returns a clone of ``input`` without moving it."""
    duplicate = input.clone()
    return duplicate
@meta_patched_function.register(torch.Tensor.cuda)
def torch_tensor_cuda(input, *args, **kwargs):
    """Meta patch for ``Tensor.cuda``: returns a clone of ``input`` without moving it.

    Extra positional and keyword arguments are accepted and ignored.
    """
    duplicate = input.clone()
    return duplicate
from .activation_function import *
from .convolution import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
from .rnn import *
\ No newline at end of file
import torch
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.ReLU)
@meta_patched_module.register(torch.nn.Sigmoid)
@meta_patched_module.register(torch.nn.GELU)
@meta_patched_module.register(torch.nn.Tanh)
@meta_patched_module.register(torch.nn.ReLU6)
@meta_patched_module.register(torch.nn.PReLU)
def torch_nn_non_linear_act(self, input):
    """Meta patch shared by the element-wise activation modules registered above.

    The output is an uninitialized meta tensor shaped like ``input``.
    """
    output_shape = input.shape
    return torch.empty(output_shape, device='meta')
import math
import torch
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.Conv1d)
def torch_nn_conv1d(self, input):
    """Meta patch for ``torch.nn.Conv1d``.

    The last two axes of ``input`` become ``(out_channels, l_out)``.
    """
    # output length formula from
    # https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d
    l_in = input.shape[-1]
    kernel_span = self.dilation[0] * (self.kernel_size[0] - 1)
    l_out = math.floor((l_in + 2 * self.padding[0] - kernel_span - 1) / self.stride[0] + 1)
    return torch.empty(input.shape[:-2] + (self.out_channels, l_out), device='meta')
@meta_patched_module.register(torch.nn.Conv2d)
def torch_nn_conv2d(self, input):
    """Meta patch for ``torch.nn.Conv2d``.

    The last three axes of ``input`` become ``(out_channels, h_out, w_out)``.
    """

    # output size formula from
    # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d
    def _out_size(size_in, axis):
        kernel_span = self.dilation[axis] * (self.kernel_size[axis] - 1)
        return math.floor((size_in + 2 * self.padding[axis] - kernel_span - 1) / self.stride[axis] + 1)

    h_in, w_in = input.shape[-2:]
    spatial_out = (_out_size(h_in, 0), _out_size(w_in, 1))
    return torch.empty(input.shape[:-3] + (self.out_channels,) + spatial_out, device='meta')
@meta_patched_module.register(torch.nn.Conv3d)
def torch_nn_conv3d(self, input):
    """Meta patch for ``torch.nn.Conv3d``.

    The last four axes of ``input`` become
    ``(out_channels, d_out, h_out, w_out)``.
    """

    # output size formula from
    # https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d
    def _out_size(size_in, axis):
        kernel_span = self.dilation[axis] * (self.kernel_size[axis] - 1)
        return math.floor((size_in + 2 * self.padding[axis] - kernel_span - 1) / self.stride[axis] + 1)

    d_in, h_in, w_in = input.shape[-3:]
    spatial_out = (_out_size(d_in, 0), _out_size(h_in, 1), _out_size(w_in, 2))
    return torch.empty(input.shape[:-4] + (self.out_channels,) + spatial_out, device='meta')
@meta_patched_module.register(torch.nn.ConvTranspose1d)
def torch_nn_convtranspose1d(self, input):
    """Meta patch for ``torch.nn.ConvTranspose1d``.

    The last two axes of ``input`` become ``(out_channels, l_out)``.
    """
    # output length formula from
    # https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
    l_in = input.shape[-1]
    kernel_span = self.dilation[0] * (self.kernel_size[0] - 1)
    l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + kernel_span + self.output_padding[0] + 1)
    return torch.empty(input.shape[:-2] + (self.out_channels, l_out), device='meta')
@meta_patched_module.register(torch.nn.ConvTranspose2d)
def torch_nn_convtranspose2d(self, input):
    """Meta patch for ``torch.nn.ConvTranspose2d``.

    The last three axes of ``input`` become ``(out_channels, h_out, w_out)``.
    """

    # output size formula from
    # https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
    def _out_size(size_in, axis):
        kernel_span = self.dilation[axis] * (self.kernel_size[axis] - 1)
        return math.floor((size_in - 1) * self.stride[axis] - 2 * self.padding[axis] + kernel_span +
                          self.output_padding[axis] + 1)

    h_in, w_in = input.shape[-2:]
    spatial_out = (_out_size(h_in, 0), _out_size(w_in, 1))
    return torch.empty(input.shape[:-3] + (self.out_channels,) + spatial_out, device='meta')
@meta_patched_module.register(torch.nn.ConvTranspose3d)
def torch_nn_convtranspose3d(self, input):
    """Meta patch for ``torch.nn.ConvTranspose3d``.

    The last four axes of ``input`` become
    ``(out_channels, d_out, h_out, w_out)``.
    """

    # output size formula from
    # https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
    def _out_size(size_in, axis):
        kernel_span = self.dilation[axis] * (self.kernel_size[axis] - 1)
        return math.floor((size_in - 1) * self.stride[axis] - 2 * self.padding[axis] + kernel_span +
                          self.output_padding[axis] + 1)

    d_in, h_in, w_in = input.shape[-3:]
    spatial_out = (_out_size(d_in, 0), _out_size(h_in, 1), _out_size(w_in, 2))
    return torch.empty(input.shape[:-4] + (self.out_channels,) + spatial_out, device='meta')
import torch
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.Embedding)
def torch_nn_embedding(self, input):
    """Meta patch for ``torch.nn.Embedding``.

    The output shape is ``input.shape`` with ``embedding_dim`` appended.
    """
    embedding_axis = (self.embedding_dim,)
    return torch.empty(input.shape + embedding_axis, device='meta')
import torch
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
    """Meta patch for ``torch.nn.Linear``.

    Asserts that the last dimension of ``input`` equals ``in_features`` and
    returns a meta tensor whose last dimension is ``out_features``.
    """
    last_dim = input.shape[-1]
    assert last_dim == self.in_features, \
        f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch'
    output_shape = input.shape[:-1] + (self.out_features,)
    return torch.empty(output_shape, device="meta")
import torch
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.LayerNorm)
@meta_patched_module.register(torch.nn.GroupNorm)
@meta_patched_module.register(torch.nn.BatchNorm1d)
@meta_patched_module.register(torch.nn.BatchNorm2d)
@meta_patched_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self, input):
    """Meta patch shared by the normalization modules registered above.

    For the BatchNorm variants the input rank is asserted first (2 or 3 for
    1d, 4 for 2d, 5 for 3d). The result is a clone of ``input``, since
    normalization keeps the shape unchanged.
    """
    # (module class, accepted input ranks); the first matching class wins
    rank_rules = (
        (torch.nn.BatchNorm1d, [2, 3]),
        (torch.nn.BatchNorm2d, [4]),
        (torch.nn.BatchNorm3d, [5]),
    )
    for norm_cls, accepted_ranks in rank_rules:
        if isinstance(self, norm_cls):
            assert input.dim() in accepted_ranks
            break

    return input.clone()
# Best-effort registration of apex's fused normalization layers with the same
# shape-preserving patch. If apex cannot be imported, or one of the classes
# below is missing, registration stops at that point and the error is ignored.
try:
    import apex
    meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
    meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
    meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
    meta_patched_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize)
except (ImportError, AttributeError):
    # apex is optional: skip silently when it (or one of its classes) is unavailable
    pass
import math
import torch
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.AvgPool1d)
def torch_nn_avgpool1d(self, input):
    """Meta patch for ``torch.nn.AvgPool1d``.

    Accepts a 2-D or 3-D input and replaces its last axis with
    ``floor((l_in + 2*padding - kernel_size) / stride + 1)``.
    """
    num_dim = input.dim()
    assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'

    def _as_sequence(value):
        # int hyper-parameters apply to the single spatial axis
        return [value] * 1 if isinstance(value, int) else value

    padding, kernel_size, stride = map(_as_sequence, (self.padding, self.kernel_size, self.stride))
    l_in = input.shape[-1]
    l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
    return torch.empty(tuple(input.shape[:-1]) + (l_out,), device='meta')
@meta_patched_module.register(torch.nn.AvgPool2d)
def torch_nn_avgpool2d(self, input):
    """Meta patch for ``torch.nn.AvgPool2d``.

    Accepts a 3-D or 4-D input and replaces its last two axes, each computed
    as ``floor((in + 2*padding - kernel_size) / stride + 1)``.
    """
    num_dim = input.dim()
    assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'

    def _as_sequence(value):
        # int hyper-parameters apply to both spatial axes
        return [value] * 2 if isinstance(value, int) else value

    padding, kernel_size, stride = map(_as_sequence, (self.padding, self.kernel_size, self.stride))
    h_in, w_in = input.shape[-2:]
    spatial_out = tuple(
        math.floor((size_in + 2 * padding[axis] - kernel_size[axis]) / stride[axis] + 1)
        for axis, size_in in enumerate((h_in, w_in)))
    return torch.empty(tuple(input.shape[:-2]) + spatial_out, device='meta')
@meta_patched_module.register(torch.nn.AvgPool3d)
def torch_nn_avgpool3d(self, input):
    """Meta patch for ``torch.nn.AvgPool3d``.

    Accepts a 4-D or 5-D input and replaces its last three axes, each
    computed as ``floor((in + 2*padding - kernel_size) / stride + 1)``.
    """
    num_dim = input.dim()
    assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions'

    def _as_sequence(value):
        # int hyper-parameters apply to all three spatial axes
        return [value] * 3 if isinstance(value, int) else value

    padding, kernel_size, stride = map(_as_sequence, (self.padding, self.kernel_size, self.stride))
    d_in, h_in, w_in = input.shape[-3:]
    spatial_out = tuple(
        math.floor((size_in + 2 * padding[axis] - kernel_size[axis]) / stride[axis] + 1)
        for axis, size_in in enumerate((d_in, h_in, w_in)))
    return torch.empty(tuple(input.shape[:-3]) + spatial_out, device='meta')
@meta_patched_module.register(torch.nn.MaxPool1d)
def torch_nn_maxpool1d(self, input):
    """Meta patch for ``torch.nn.MaxPool1d``.

    Accepts a 2-D or 3-D input and replaces its last axis with
    ``floor((l_in + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1)``.
    """
    num_dim = input.dim()
    assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'

    def _as_sequence(value):
        # int hyper-parameters apply to the single spatial axis
        return [value] * 1 if isinstance(value, int) else value

    padding, dilation, kernel_size, stride = map(_as_sequence,
                                                 (self.padding, self.dilation, self.kernel_size, self.stride))
    l_in = input.shape[-1]
    l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    return torch.empty(tuple(input.shape[:-1]) + (l_out,), device='meta')
@meta_patched_module.register(torch.nn.MaxPool2d)
def torch_nn_maxpool2d(self, input):
    """Meta patch for ``torch.nn.MaxPool2d``.

    Accepts a 3-D or 4-D input and replaces its last two axes, each computed
    as ``floor((in + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1)``.
    """
    num_dim = input.dim()
    assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'

    def _as_sequence(value):
        # int hyper-parameters apply to both spatial axes
        return [value] * 2 if isinstance(value, int) else value

    padding, dilation, kernel_size, stride = map(_as_sequence,
                                                 (self.padding, self.dilation, self.kernel_size, self.stride))
    h_in, w_in = input.shape[-2:]
    spatial_out = tuple(
        math.floor((size_in + 2 * padding[axis] - dilation[axis] * (kernel_size[axis] - 1) - 1) / stride[axis] + 1)
        for axis, size_in in enumerate((h_in, w_in)))
    return torch.empty(tuple(input.shape[:-2]) + spatial_out, device='meta')
@meta_patched_module.register(torch.nn.MaxPool3d)
def torch_nn_maxpool3d(self, input):
    """Meta patch for ``torch.nn.MaxPool3d``.

    Accepts a 4-D or 5-D input and replaces its last three axes, each computed
    as ``floor((in + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1)``.
    """
    num_dim = input.dim()
    assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions'

    def _as_sequence(value):
        # int hyper-parameters apply to all three spatial axes
        return [value] * 3 if isinstance(value, int) else value

    padding, dilation, kernel_size, stride = map(_as_sequence,
                                                 (self.padding, self.dilation, self.kernel_size, self.stride))
    d_in, h_in, w_in = input.shape[-3:]
    spatial_out = tuple(
        math.floor((size_in + 2 * padding[axis] - dilation[axis] * (kernel_size[axis] - 1) - 1) / stride[axis] + 1)
        for axis, size_in in enumerate((d_in, h_in, w_in)))
    return torch.empty(tuple(input.shape[:-3]) + spatial_out, device='meta')
@meta_patched_module.register(torch.nn.AdaptiveAvgPool1d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool1d)
def torch_nn_adapative_pooling_1d(self, input):
    """Meta patch for the 1-D adaptive pooling modules.

    Accepts a 2-D or 3-D input and replaces its last axis with the module's
    ``output_size`` (an int is wrapped in a 1-tuple).
    """
    assert input.dim() in [2, 3]
    target_size = self.output_size
    if isinstance(target_size, int):
        target_size = (target_size,)
    return torch.empty(tuple(input.shape[:-1]) + target_size, device='meta')
@meta_patched_module.register(torch.nn.AdaptiveAvgPool2d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool2d)
def torch_nn_adapative_pooling_2d(self, input):
    """Meta patch for the 2-D adaptive pooling modules.

    Accepts a 3-D or 4-D input and replaces its last two axes with the
    module's ``output_size`` (an int is duplicated for both axes).
    """
    assert input.dim() in [3, 4]
    target_size = self.output_size
    if isinstance(target_size, int):
        target_size = (target_size,) * 2
    return torch.empty(tuple(input.shape[:-2]) + target_size, device='meta')
@meta_patched_module.register(torch.nn.AdaptiveAvgPool3d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool3d)
def torch_nn_adapative_pooling_3d(self, input):
    """Meta patch for the 3-D adaptive pooling modules.

    Accepts a 4-D or 5-D input and replaces its last three axes with the
    module's ``output_size`` (an int is repeated for all three axes).
    """
    assert input.dim() in [4, 5]
    target_size = self.output_size
    if isinstance(target_size, int):
        target_size = (target_size,) * 3
    return torch.empty(tuple(input.shape[:-3]) + target_size, device='meta')
from typing import Optional
import torch
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.GRU)
@meta_patched_module.register(torch.nn.RNN)
def torch_nn_rnn(self, input, hx):
    """Meta patch for ``torch.nn.RNN`` and ``torch.nn.GRU``.

    Asserts that the trailing dimensions of ``input`` and ``hx`` match
    ``input_size`` and ``hidden_size``. Returns ``(output, hx)`` where
    ``output`` is a meta tensor whose last dimension is ``hidden_size``
    (doubled when the module is bidirectional) and ``hx`` is passed through.
    """
    assert input.shape[-1] == self.input_size, \
        f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch'
    assert hx.shape[-1] == self.hidden_size, \
        f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch'
    num_directions = 2 if self.bidirectional else 1
    output_shape = input.shape[:-1] + (self.hidden_size * num_directions,)
    return torch.empty(output_shape, device="meta"), hx
class PatchRegistry:
    """A named lookup table mapping a patch source to its replacement callable."""

    def __init__(self, name):
        self.name = name
        # source -> replacement callable
        self.store = {}

    def register(self, source):
        """Return a decorator that stores the decorated callable under ``source``.

        The decorated callable is returned unchanged, so decorators can be stacked.
        """

        def _record(func):
            self.store[source] = func
            return func

        return _record

    def get(self, source):
        """Return the callable registered for ``source``; asserts that it exists."""
        assert self.has(source)
        return self.store[source]

    def has(self, source):
        """Tell whether a callable is registered for ``source``."""
        return source in self.store
# Module-level registries. The decorators in the patch modules populate the
# two meta-execution registries; the bias-addition registries are looked up
# by the tracer by function, module, or method.
meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition')
#!/usr/bin/env python
"""
tracer.py:
Implemented a tracer which supports control flow and user-defined meta arguments.
The implementation is partly inspired by HuggingFace's fx tracer
"""
import enum
import functools
import inspect
import operator
from contextlib import contextmanager
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.fx import Node, Tracer
from torch.fx.graph import Graph, magic_methods, reflectable_magic_methods
from torch.fx.proxy import ParameterProxy, Proxy
from ..proxy import ColoProxy
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
from .registry import (
bias_addition_function,
bias_addition_method,
bias_addition_module,
meta_patched_function,
meta_patched_module,
)
# Public API of this module: only the tracer class is exported.
__all__ = ['ColoTracer']
class TracerType(enum.Enum):
    """Tracing mode used by ``ColoTracer`` when creating proxies."""

    # fall back to the stock torch.fx.Tracer behaviour
    DEFAULT = 1
    # trace with meta tensors
    META = 2
class ColoTracer(Tracer):
"""
ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module.
This tracer is initialized in the same way as the original torch.fx.Tracer.
Usage::
class Model(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 10)
self.linear2 = nn.Linear(10, 10)
def forward(self, x, y):
x1 = self.linear1(x)
y1 = self.linear2(y)
if x1.dim() == 2:
return x1 + y1
else:
return x1 - y1
model = Model()
tracer = ColoTracer()
graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')})
"""
def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tracer_type = TracerType.META
self.proxy_cls = ColoProxy
# whether the tracer will record the usage of torch.utils.checkpoint
self.trace_act_ckpt = trace_act_ckpt
# whether the current tracing occurs within the activation checkpoint functions
self.inside_torch_checkpoint_func = False
self.act_ckpt_region_count = 0
# Feature flag for proxying accesses to buffer values
proxy_buffer_attributes: bool = True
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor", "finfo"]
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
    """
    Create a proxy for different kinds of operations.

    In default mode the call is delegated unchanged to the parent tracer. In meta mode, a registered
    bias-addition handler (looked up by function, method or module type) may take over and generate
    the proxy itself; otherwise the parent tracer creates the proxy and the meta data computed by
    `_meta_data_computing` is attached to it as `meta_data`.

    Raises:
        AttributeError: for a `call_module` node when `orig_forward` has not been set on the tracer.
    """
    # keep the untouched arguments together: they are what the parent tracer receives
    # whenever no extra manipulation is applied
    origin_arguments = (kind, target, args, kwargs, name, type_expr, proxy_factory_fn)

    if self.tracer_type == TracerType.DEFAULT:
        # since meta_args is not given, fall back to the original torch.fx.Tracer
        return super().create_proxy(*origin_arguments)

    # If the graph is traced for the auto parallelism module, extra nodes are added during graph
    # construction to deal with the compatibility between bias addition and all reduce.
    # The handler that generates those nodes is selected by the kind and target of the operation.
    args_metas, _ = extract_meta(*args, **kwargs)

    handle = None
    if kind == "call_function":
        if bias_addition_function.has(target):
            # linear only needs the substitution when a bias is actually passed as keyword argument
            substitute = True
            if target == torch.nn.functional.linear:
                substitute = 'bias' in kwargs and kwargs['bias'] is not None
            if substitute:
                function_to_substitute = func_to_func_dict[target]
                handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
        elif bias_addition_function.has(target.__name__):
            # use name for some builtin op like @ (matmul)
            function_to_substitute = func_to_func_dict[target]
            handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs,
                                                                 function_to_substitute)

    elif kind == "call_method":
        # the method is resolved on the class of the meta data of the first argument
        method = getattr(args_metas[0].__class__, target)
        if bias_addition_method.has(method):
            function_to_substitute = method_to_func_dict[method]
            handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute)

    elif kind == "call_module":
        if not hasattr(self, "orig_forward"):
            raise AttributeError(f"{self} does not have an attribute called orig_forward")
        # attribute accesses made while inspecting the module must not create proxies
        self._disable_module_getattr = True
        try:
            mod = self.root.get_submodule(target)
            mod_type = type(mod)
            if bias_addition_module.has(mod_type) and mod.bias is not None:
                function_to_substitute = module_to_func_dict[mod_type]
                handle = bias_addition_module.get(mod_type)(self, target, args, kwargs, function_to_substitute)
        finally:
            self._disable_module_getattr = False

    if handle is not None:
        return handle.generate()

    # no handler applies: create the node from the original arguments and attach its meta data
    proxy = super().create_proxy(*origin_arguments)
    proxy.meta_data = self._meta_data_computing(kind, target, args, kwargs)
    return proxy
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, "_disable_module_getattr", False):
return attr_val
else:
# return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
for n, p in collection_to_search:
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
lambda node: ParameterProxy(self, node, n, attr_val))
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
parameter_proxy_cache)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
parameter_proxy_cache)
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
return attr_val
def call_module(self, m, forward, args, kwargs):
    """
    Handle the invocation of module `m` during tracing.

    The module's forward is remembered as `orig_forward`. A leaf module is recorded as a single
    `call_module` proxy; any other module is traced through by calling `forward` directly.
    """
    self.orig_forward = forward
    module_qualified_name = self.path_of_module(m)

    # A leaf module is a torch.nn.Module subclass whose path starts with `torch.nn`, so customized
    # modules are not leaf modules by default. A customized or third-party module that has been
    # patched (e.g. apex.normalization.FusedRMSNorm) is treated as a leaf module as well.
    treated_as_leaf = meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name)
    if not treated_as_leaf:
        return forward(*args, **kwargs)
    return self.create_proxy('call_module', module_qualified_name, args, kwargs)
def proxy(self, node) -> Proxy:
    """
    Wrap `node` in a proxy of the currently configured proxy class (`self.proxy_cls`),
    bound to this tracer.
    """
    proxy_factory = self.proxy_cls
    return proxy_factory(node, self)
def _configure_tracer_type(self, tracer_type: TracerType):
    """
    Switch the tracer to the given mode and select the matching proxy class.

    Raises:
        ValueError: if `tracer_type` is neither `TracerType.DEFAULT` nor `TracerType.META`;
            the tracer is left unchanged in that case.
    """
    if tracer_type == TracerType.DEFAULT:
        selection = (Proxy, TracerType.DEFAULT)
    elif tracer_type == TracerType.META:
        selection = (ColoProxy, TracerType.META)
    else:
        raise ValueError(f"Unrecognised tracer type {tracer_type}")

    self.proxy_cls, self.tracer_type = selection
def _meta_data_computing(self, kind, target, args, kwargs):
if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
meta_out = self.meta_args[target]
return meta_out
if target in self.orig_torch_tensor_methods:
# NOTE: tensor constructors in PyTorch define the `device` argument as
# *kwargs-only*. That is why this works. If you add methods to
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
# this will break and you will likely see issues where we cannot infer
# the size of the output.
if "device" in kwargs:
kwargs["device"] = "meta"
try:
args_metas, kwargs_metas = extract_meta(*args, **kwargs)
if kind == "call_function":
# fetch patched function
if meta_patched_function.has(target):
meta_target = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
meta_target = meta_patched_function.get(target.__name__)
else:
meta_target = target
meta_out = meta_target(*args_metas, **kwargs_metas)
if isinstance(meta_out, torch.Tensor):
meta_out = meta_out.to(device="meta")
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
# fetch patched method
if meta_patched_function.has(method):
meta_target = meta_patched_function.get(method)
else:
meta_target = method
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == "call_module":
if not hasattr(self, "orig_forward"):
raise AttributeError(f"{self} does not have an attribute called orig_forward")
self._disable_module_getattr = True
try:
mod = self.root.get_submodule(target)
mod_type = type(mod)
if meta_patched_module.has(mod_type):
meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas)
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
finally:
self._disable_module_getattr = False
elif kind == "get_attr":
self._disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.nn.parameter.Parameter):
meta_out = torch.nn.Parameter(attr_itr.to(device="meta"))
elif isinstance(attr_itr, torch.Tensor):
meta_out = attr_itr.to(device="meta")
else:
meta_out = attr_itr
finally:
self._disable_module_getattr = False
else:
return None
except Exception as e:
raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
return meta_out
def trace(self,
          root: nn.Module,
          concrete_args: Optional[Dict[str, Tensor]] = None,
          meta_args: Optional[Dict[str, Tensor]] = None) -> Graph:
    """
    Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow.

    Args:
        root (nn.Module): a `nn.Module` object to trace the computation graph
        meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph.
            These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors.
            If no meta argument is given, the tracer falls back to the default tracing mode.
        concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies.
            The mapping passed in is not modified; default values from the forward signature are
            merged into a private copy.

    Returns:
        Graph: the traced graph, which is also stored as `self.graph`.

    Raises:
        KeyError: if a name in `meta_args` or `concrete_args` is not found in the signature of `root.forward`.
        AssertionError: if a concrete argument is a meta tensor, or a meta argument is not a meta tensor.
    """
    if meta_args is None:
        meta_args = {}

    # default values are added to the concrete args below, so never write into the caller's dict
    concrete_args = {} if concrete_args is None else dict(concrete_args)

    if len(meta_args) == 0:
        self._configure_tracer_type(TracerType.DEFAULT)
    else:
        self._configure_tracer_type(TracerType.META)

    # check concrete and meta args have valid names
    sig = inspect.signature(root.forward)
    sig_names = set(sig.parameters.keys())
    meta_arg_names = set(meta_args.keys())

    # update concrete args with default values
    non_meta_arg_names = sig_names - meta_arg_names
    for k, v in sig.parameters.items():
        if k in non_meta_arg_names and \
                k not in concrete_args and \
                v.default is not inspect.Parameter.empty:
            concrete_args[k] = v.default

    # get non concrete arg names
    concrete_arg_names = set(concrete_args.keys())
    non_concrete_arg_names = sig_names - concrete_arg_names

    def _check_arg_name_valid(names):
        success, element = is_element_in_list(names, sig_names)
        if not success:
            raise KeyError(
                f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")

    _check_arg_name_valid(meta_arg_names)
    _check_arg_name_valid(concrete_arg_names)

    def _check_kwargs(kwargs, should_be_meta: bool):
        for k, v in kwargs.items():
            if not should_be_meta:
                assert not torch.is_tensor(v) or not v.is_meta, \
                    f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer'
            else:
                assert v.is_meta == should_be_meta, \
                    f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'

    _check_kwargs(concrete_args, should_be_meta=False)
    _check_kwargs(meta_args, should_be_meta=True)

    # assign as attributes for later reference
    self.concrete_args = concrete_args
    self.meta_args = meta_args

    self.patched_torch_tensor_methods = {}
    if self.tracer_type == TracerType.META:
        # wrap the torch tensor constructing methods so that they are captured in the graph
        self.patched_torch_tensor_methods = {
            target: wrap_tensor_constructor_method(getattr(torch, target))
            for target in self._TORCH_METHODS_TO_PATCH
        }

        # cache these methods so that we can detect whether a method call
        # should be patched during tracing
        self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()]

    try:
        # patch these methods to replace their original use; this happens inside the try block
        # so that the originals are restored even if patching or tracing fails half-way
        for name, (wrapper, _) in self.patched_torch_tensor_methods.items():
            setattr(torch, name, wrapper)

        # to track the usage of torch.utils.checkpoint
        with self.trace_activation_checkpoint(enabled=self.trace_act_ckpt):
            self.graph = super().trace(root, concrete_args=concrete_args)

    finally:
        # recover the patched methods
        for name, (_, orig) in self.patched_torch_tensor_methods.items():
            setattr(torch, name, orig)

    if self.tracer_type == TracerType.DEFAULT:
        return self.graph

    # This is necessary because concrete args are added as input to the traced module since
    # https://github.com/pytorch/pytorch/pull/55888.
    for node in self.graph.nodes:
        if node.op == "placeholder":
            # Removing default values for inputs as the forward pass will fail with them.
            if node.target in non_concrete_arg_names:
                node.args = ()
                # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                # It cannot infer on the attributes and methods the input should have, and fails.
                node.type = torch.Tensor
            # It is a concrete arg so it is not used and should be removed.
            else:
                if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
                    # Newer versions of torch.fx emit an assert statement
                    # for concrete arguments; delete those before we delete
                    # the concrete arg.
                    to_delete = []
                    for user in node.users:
                        if user.target == torch.fx._symbolic_trace._assert_is_none:
                            to_delete.append(user)
                    for user in to_delete:
                        self.graph.erase_node(user)

                self.graph.erase_node(node)

        # TODO: solves GraphModule creation.
        # Without this, return type annotation "Tuple" is causing code execution failure.
        if node.op == "output":
            node.type = None

    return self.graph
@contextmanager
def trace_activation_checkpoint(self, enabled: bool):
    """
    Context manager which, when `enabled`, replaces `torch.utils.checkpoint.CheckpointFunction`
    with a version that marks on the tracer when tracing runs inside a checkpointed function.

    The original checkpoint function is restored when the context exits, including when the
    traced code raises. When `enabled` is False the context does nothing.

    Args:
        enabled (bool): whether the checkpoint function should be patched.
    """
    if not enabled:
        yield
        return

    orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction

    class PatchedCheckpointFunction(torch.autograd.Function):

        @staticmethod
        def forward(ctx, run_function, preserve_rng_state, *args):
            # signal that the current tracing occurs within the activation checkpoint part
            self.inside_torch_checkpoint_func = True
            try:
                out = run_function(*args)
            finally:
                # never leave the flag set, otherwise later nodes would be annotated by mistake
                self.inside_torch_checkpoint_func = False
            # a region is only counted once it has been traced completely
            self.act_ckpt_region_count += 1
            return out

        @staticmethod
        def backward(ctx: Any, *grad_outputs: Any) -> Any:
            raise NotImplementedError(
                "We do not implement the backward pass as we only trace the forward pass.")

    # override the checkpoint function
    torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
    try:
        yield
    finally:
        # recover the checkpoint function upon exit, even if tracing failed
        torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
def create_node(self, *args, **kwargs) -> Node:
    """
    Create a node through the parent tracer and, while tracing inside an activation checkpoint
    function, tag it with the index of the current checkpoint region under
    `node.meta['activation_checkpoint']`.
    """
    created = super().create_node(*args, **kwargs)

    if not self.inside_torch_checkpoint_func:
        return created

    # annotate the activation checkpoint module
    created.meta['activation_checkpoint'] = self.act_ckpt_region_count
    return created
def wrap_tensor_constructor_method(target):
    """
    Wrap a tensor constructing function so that calls receiving a proxy are recorded in the graph.

    Args:
        target: the original function, e.g. `torch.ones`.

    Returns:
        A tuple `(wrapper, target)`: the wrapping function and the untouched original.
        When none of the arguments contains a proxy, the wrapper simply calls `target`.
    """

    def look_for_proxy(*args, **kwargs):
        # Search positional values first, then keyword values, descending into tuples and lists.
        # A container without any proxy must not end the search, since a later argument may still
        # hold one, e.g. torch.full((2, 3), value) where value is a proxy.
        for value in (*args, *kwargs.values()):
            if isinstance(value, Proxy):
                return value
            if isinstance(value, (tuple, list)):
                nested = look_for_proxy(*value)
                if nested is not None:
                    return nested
        return None

    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = look_for_proxy(*args, **kwargs)

        if proxy is None:
            # this is called directly when the inputs do not contain proxy
            # e.g. torch.ones(4) where the input is static
            return target(*args, **kwargs)

        # if the arg is a proxy, then need to record this function called on this proxy
        # e.g. torch.ones(size) where size is an input proxy
        colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs)
        if not isinstance(colo_proxy, ColoProxy):
            meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
            # wrap the node that was just created for this call, not the node of the input proxy
            colo_proxy = ColoProxy(colo_proxy.node)
            colo_proxy.meta_data = meta_out
        return colo_proxy

    return wrapper, target
# Patched magic methods for ColoProxy, then tracer could record the magic_method like __sub__,
# and add meta_data attribute to the created proxy.
def _scope(method):
    """Install on ColoProxy the traced implementation of the magic method named `method`."""

    def impl(*args, **kwargs):
        owning_tracer = args[0].tracer
        # the operator is looked up when the magic method is called, not when it is installed
        target = getattr(operator, method)
        result = owning_tracer.create_proxy('call_function', target, args, kwargs)
        if isinstance(result, ColoProxy):
            return result
        # the tracer returned a plain proxy: compute the meta data here and re-wrap the node
        meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
        wrapped = ColoProxy(result.node)
        wrapped.meta_data = meta_out
        return wrapped

    impl.__name__ = method
    setattr(ColoProxy, f'__{method.strip("_")}__', impl)


for method in magic_methods:
    _scope(method)
def _define_reflectable(orig_method_name):
    """
    Install on ColoProxy the reflected variant (`__r<name>__`) of the magic method
    named `orig_method_name`, so that `rhs <op> proxy` is recorded in the graph.
    """
    method_name = f'__r{orig_method_name.strip("_")}__'

    def impl(self, rhs):
        # the operator is looked up when the magic method is called, not when it is installed
        target = getattr(operator, orig_method_name)
        # operands are swapped because the proxy is the right-hand side of the operation
        operands = (rhs, self)
        proxy = self.tracer.create_proxy('call_function', target, operands, {})
        if not isinstance(proxy, ColoProxy):
            # pass the operands as one tuple, matching how the other call sites in this file
            # invoke the helper: (target, args, kwargs)
            meta_out = compute_meta_data_for_functions_proxy(target, operands, {})
            proxy = ColoProxy(proxy.node)
            proxy.meta_data = meta_out
        return proxy

    impl.__name__ = method_name
    impl.__qualname__ = method_name
    setattr(ColoProxy, method_name, impl)


for orig_method_name in reflectable_magic_methods:
    _define_reflectable(orig_method_name)
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .gemini_mgr import GeminiManager
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import TensorPlacementPolicyFactory
__all__ = [
'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager',
'search_chunk_configuration'
]
from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState
from .manager import ChunkManager
from .search_utils import classify_params_by_dp_degree, search_chunk_configuration
from .utils import init_chunk_manager
__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment