Unverified Commit 7ea6bc7f authored by Boyuan Yao, committed by GitHub

[autoparallel] Patch tensor related operations meta information (#2789)

* [autoparallel] tensor related meta information prototype

* [autoparallel] tensor related meta information

* [autoparallel] tensor related meta information

* [autoparallel] tensor related meta information

* [autoparallel] tensor related meta information
parent a5721229
@@ -5,3 +5,4 @@ from .embedding import *
 from .linear import *
 from .norm import *
 from .pooling import *
+from .tensor import *
@@ -14,7 +14,6 @@ __all__ = ["avgpool_meta_info", "maxpool_meta_info"]
 @meta_register.register(torch.nn.AdaptiveAvgPool1d)
 @meta_register.register(torch.nn.AdaptiveAvgPool2d)
 @meta_register.register(torch.nn.AdaptiveAvgPool3d)
-@meta_register.register(torch.flatten)
 def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     """Meta info for AdaptiveAvgPool
     The aten graph of AdaptiveAvgPool is
...
from typing import Callable, List, Tuple

import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size

from ..registry import meta_register

__all__ = ["tensor_related_metainfo"]
def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: float = 0) -> Callable:
    """Metainfo generator template for torch.Tensor related operations.

    Args:
        bwd_mem_out_factor (float, optional): backward activation memory cost factor. Defaults to 1.
        bwd_mem_tmp_factor (float, optional): backward temp memory cost factor. Defaults to 0.

    Returns:
        Callable: torch.Tensor related metainfo generator
    """

    def meta_func(
        *args, **kwargs
    ) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """torch.Tensor related metainfo generator

        Returns:
            Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
                compute cost, memory cost, forward inputs, forward buffers and forward outputs
        """
        outputs = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data

        # these operations are not compute-intensive, so their compute costs are treated as zero
        compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0)

        # memory costs
        # NOTE: the SPMD solver currently assumes that a new tensor is always created in the
        # forward pass, so the output activation is counted twice
        fwd_mem_cost = MemoryCost(activation=activation_size(outputs) * 2, parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=activation_size(outputs) * bwd_mem_out_factor,
                                  parameter=0,
                                  temp=activation_size(outputs) * bwd_mem_tmp_factor,
                                  buffer=0)
        total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
                                    parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
                                    temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
                                    buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        # store fwd_in, fwd_buffer and fwd_out
        fwd_in = []
        fwd_buffer = []
        if isinstance(outputs, dict):
            # a dict of output tensors: zero-fill the values
            fwd_out = [torch.zeros_like(tensor) for tensor in outputs.values()]
        elif isinstance(outputs, (tuple, list)):
            # a sequence of output tensors, e.g. from torch.split
            fwd_out = [torch.zeros_like(tensor) for tensor in outputs]
        else:
            # outputs is a single tensor
            fwd_out = [torch.zeros_like(outputs)]

        return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

    return meta_func
# register torch.Tensor related metainfo
# (bwd_mem_out_factor, bwd_mem_tmp_factor) = (0, 0): no backward memory cost
meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze,
                        torch.arange])(tensor_related_metainfo(0, 0))

# (bwd_mem_out_factor, bwd_mem_tmp_factor) = (1, 0): backward activation only
meta_register.register([
    torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute,
    torch.Tensor.split, torch.split, torch.Tensor.view
])(tensor_related_metainfo(1, 0))

# (bwd_mem_out_factor, bwd_mem_tmp_factor) = (1, 1): backward activation plus temp memory
meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1))
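To make the two registration factors concrete, here is a minimal sketch of the resulting cost arithmetic, assuming activation_size returns the total byte size of the tensors it is given; the nbytes helper is a hypothetical stand-in for the single-tensor case.

import torch


def nbytes(t: torch.Tensor) -> int:
    # hypothetical stand-in for activation_size on one dense tensor
    return t.numel() * t.element_size()


out = torch.rand(1024, 1024, device="meta")    # 4 MiB of fp32, no real storage
size = nbytes(out)

# factors (1, 0), e.g. torch.flatten: the forward activation is counted twice
# because the solver assumes a new tensor is created in the forward pass
fwd_activation = size * 2    # 8 MiB
bwd_activation = size * 1    # bwd_mem_out_factor = 1 -> 4 MiB
bwd_temp = size * 0          # bwd_mem_tmp_factor = 0 -> 0 B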
import pytest
import torch
import torch.nn as nn
from packaging import version

from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    from colossalai.auto_parallel.meta_profiler import meta_register
class SplitModule(nn.Module):

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x.split(512, dim=0)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'),
                    reason="need pytorch 1.12.0 or higher for aten level operations")
def test_tensor_meta_info():
    """Test tensor related meta information.

    torch.Tensor.split serves as the representative case.
    """
    meta_func = meta_register.get(torch.Tensor.split)

    # construct meta tensors
    input_tensor = torch.rand(1024, 1024, device="meta")
    output_tensor = input_tensor.split(512, dim=0)

    # construct operation data
    input_data = OperationData(
        name="input",
        data=input_tensor,
        type=OperationDataType.ARG,
        logical_shape=input_tensor.shape,
    )
    output_data = OperationData(
        name="output",
        data=output_tensor,
        type=OperationDataType.OUTPUT,
        logical_shape=input_tensor.shape,
    )
    split_info_data = OperationData(
        name='split_info',
        type=OperationDataType.ARG,
        data=0,
        logical_shape=None,
    )

    # construct args and kwargs
    args = [input_data, output_data, split_info_data]
    kwargs = {'inplace': False}

    # estimated results
    compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
    # actual results
    model = SplitModule()
    input_real_tensor = torch.rand(1024, 1024).cuda()
    input_real_tensor.requires_grad = True

    # fwd
    torch.cuda.reset_peak_memory_stats()
    mem_stamp0 = torch.cuda.memory_allocated()
    output_real_tensor = model(input_real_tensor)
    fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
    fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0

    # bwd
    upstream_grad = [torch.rand_like(tensor) for tensor in output_real_tensor]
    torch.cuda.reset_peak_memory_stats()
    mem_stamp0 = torch.cuda.memory_allocated()
    torch.autograd.backward(output_real_tensor, upstream_grad)
    bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
    bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0

    print_results([input_real_tensor], output_real_tensor, compute_cost, memory_cost, fwd_allocated, fwd_peak,
                  bwd_allocated, bwd_peak)


if __name__ == "__main__":
    test_tensor_meta_info()
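For the shapes used in this test, the estimate can be checked by hand; the sketch below again assumes activation_size sums tensor byte sizes.

import torch

# splitting a 1024 x 1024 fp32 tensor yields two 512 x 1024 chunks
chunks = torch.rand(1024, 1024, device="meta").split(512, dim=0)
out_bytes = sum(t.numel() * t.element_size() for t in chunks)    # 4 MiB in total

# torch.Tensor.split is registered with factors (1, 0)
fwd_activation = out_bytes * 2    # 8 MiB: the output is counted twice in forward
bwd_activation = out_bytes * 1    # 4 MiB: gradient w.r.t. the input
bwd_temp = out_bytes * 0          # 0 B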