Commit 7bc5a8e3 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents e6748d82 0f785cb1
from .activation import *
from .binary_elementwise_ops import *
from .conv import *
from .embedding import *
from .linear import *
from .non_spmd import *
from .norm import *
from .pooling import *
from .tensor import *
from .where import *
from typing import Callable, List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import ewise_flop_counter as elementwise_flop_counter
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
__all__ = ["elementwise_meta_info"]
def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0) -> Callable:
"""This is a function to create the meta information generator for elementwise operations
Args:
temp_mem_scale (float, optional): temp memory scaling factor for backward. Defaults to 0.
buffer_mem_scale (float, optional): buffer memory scaling factor for forward. Defaults to 0.
Returns:
Callable: meta information generator
"""
def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
input_tensor = next(
filter(
lambda x:
(x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim',
args)).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
is_inplace = 1 if kwargs.get('inplace', False) else 0
flop_counter = elementwise_flop_counter(1, 0)
# calculate compute cost
fwd_compute_cost = flop_counter([input_tensor], [output_tensor])
bwd_compute_cost = flop_counter([output_tensor], [input_tensor])
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
# NOTE: if in_place is True, we will not create a new tensor in forward
fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace),
parameter=0,
temp=0,
buffer=activation_size(input_tensor) * buffer_mem_scale)
# temp_mem_scale is for situation like softmax backward
# the buffer will be removed during backward phase
bwd_memory_cost = MemoryCost(
activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale,
parameter=0,
temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale,
buffer=0)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
return meta_func
# register meta information
# (0, 0)
meta_register.register([torch.nn.ReLU, torch.nn.functional.relu, torch.tanh])(elementwise_meta_info(0, 0))
# (1, 0)
meta_register.register([torch.nn.Softmax, torch.nn.functional.softmax])(elementwise_meta_info(1, 0))
# (0, 0.25) for dropout, the buffer is in bool type so that the buffer memory cost is 0.25 times of input tensor
meta_register.register([torch.nn.Dropout, torch.nn.functional.dropout])(elementwise_meta_info(0, 0.25))
from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
from ..registry import meta_register
__all__ = ['binary_elementwise_meta_info']
@meta_register.register(BCAST_FUNC_OP)
def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""Meta information generator for binary elementwise operations
NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they
don't need those tensors for back propagation, for example, if there are two tensors being sent for `torch.add`,
they will be discarded right after add operation is done. We create a simple API in `ShardMetaInfo` class to identify
this behavior, it is critical for better memory estimation.
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))
# construct forward args for flop mapping
fwd_in_args = [opdata.data for opdata in input_op_data]
fwd_out_args = [output_op_data.data]
# calculate cost
# calculate compute cost
# NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
fwd_mem_cost = MemoryCost(
activation=activation_size(output_op_data.data),
parameter=param_mem_cost,
)
bwd_mem_cost = MemoryCost(
activation=activation_size(fwd_in_args),
parameter=param_mem_cost,
)
# total cost
total_mem_cost = MemoryCost(
activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
from typing import Callable, Dict, List, Tuple, Union
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
__all__ = ['convnd_meta_info']
@meta_register.register(torch.nn.Conv1d)
@meta_register.register(torch.nn.Conv2d)
@meta_register.register(torch.nn.Conv3d)
@meta_register.register(torch.nn.functional.conv1d)
@meta_register.register(torch.nn.functional.conv2d)
@meta_register.register(torch.nn.functional.conv3d)
def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d meta info generator
The atens graph of torch.nn.Convnd with bias is
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None), kwargs = {})
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%convolution_backward_default : [#users=3] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None, [None, None, None]), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
%detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {})
The atens graph of torch.nn.Convnd without bias is
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None], [None, None], [None, None], None, [None, None], None), kwargs = {})
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%convolution_backward_default : [#users=2] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None], [None, None], [None, None], None, [None, None], None, [None, None, None]), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
has_bias: bool = False
input_tensor = args[0].data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
if len(args) == 4:
weight_tensors = [args[1].data, args[3].data]
else:
weight_tensors = [args[1].data]
# check if conv has bias
if len(weight_tensors) > 1:
has_bias = True
# bias tensor's shape only has one dimension
if len(weight_tensors[0].shape) == 1:
bias_tensor, weight_tensor = weight_tensors
else:
weight_tensor, bias_tensor = weight_tensors
else:
weight_tensor = weight_tensors[0]
# construct input args for forward
fwd_args = [None] * 9
# weight and input
fwd_args[0] = input_tensor
fwd_args[1] = weight_tensor
fwd_args[2] = bias_tensor if has_bias else None
# transpose indicator should be set to False
fwd_args[6] = False
# construct input args for backward
bwd_args = [None] * 11
# weight and input
bwd_args[0] = output_tensor
bwd_args[1] = input_tensor
bwd_args[2] = weight_tensor
bwd_args[-1] = [True, True, True] if has_bias else [True, True, False]
# calculate cost
# the fwd op with compute cost is convolution.default
# the bwd op with compute cost is convolution_backward.default
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \
flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# TODO: use profiler to check conv temp memory
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
__all__ = ["embedding_meta_info"]
@meta_register.register(torch.nn.Embedding)
def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.nn.Embedding metainfo generator
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
weight_tensor = next(filter(lambda x: x.type == OperationDataType.PARAM, args)).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
# compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor])
bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor],
[weight_tensor])
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
# NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will
# have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume
# that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=0,
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)
total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_memory_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor)]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor)]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
from functools import reduce
from typing import Callable, Dict, List, Tuple, Union
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
__all__ = ['linear_meta_info', 'matmul_meta_info']
@meta_register.register(torch.nn.functional.linear)
@meta_register.register(torch.nn.Linear)
def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.nn.Linear & torch.nn.functional.linear meta info generator
NOTE: currently we separate the bias part from the biased linear ops, we will consider the memory consumption in add metainfo generator,
but we will hold the bias mechanism in the linear metainfo generator for future use.
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {})
%zeros_like_default : [#users=3] = call_function[target=torch.ops.aten.zeros_like.default](args = (%addmm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {})
%t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {})
%mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {})
%t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {})
%sum_dim_int_list : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%zeros_like_default, [None], None), kwargs = {})
%view_default : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_dim_int_list, [None]), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%view_default,), kwargs = {})
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default,), kwargs = {})
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
%t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {})
%detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {})
%detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {})
The one without bias is
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%input_2, None), kwargs = {})
%zeros_like_default : [#users=2] = call_function[target=torch.ops.aten.zeros_like.default](args = (%mm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {})
%mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {})
%t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {})
%mm_default_2 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default_2,), kwargs = {})
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
%t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {})
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {})
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
Returns:
Tuple[TrainCycleItem, TrainCycleItem, bool]: compute cost, memory cost and forward inputs
"""
has_bias: bool = False
input_tensor = args[0].data
output_tensor = args[2].data
if len(args) == 4:
weight_tensors = [args[1].data, args[3].data]
else:
weight_tensors = [args[1].data]
# process the dimension of input and output
if len(input_tensor.shape) > 2:
input_tensor: torch.Tensor
input_tensor = input_tensor.view(-1, input_tensor.shape[-1])
if len(output_tensor.shape) > 2:
output_tensor: torch.Tensor
output_tensor = output_tensor.view(-1, output_tensor.shape[-1])
if len(weight_tensors) > 1:
has_bias = True
if len(weight_tensors[0].shape) == 2:
weight_tensor, bias_tensor = weight_tensors
else:
bias_tensor, weight_tensor = weight_tensors
else:
weight_tensor = weight_tensors[0]
if has_bias:
# calculate cost with bias
# the fwd op with compute cost is addmm
# the bwd op with compute cost is mm * 2 and sum.dim_IntList
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
[bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# NOTE: Linear don't have buffer and temp in forward and backward phase
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0)
# the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0)
# total cost is to sum the forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
else:
# calculate cost without bias
# the fwd op with compute cost is mm
# the bwd op with compute cost is mm * 2
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# NOTE: Linear don't have buffer and temp in forward and backward phase
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
# the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
# total cost is to sum the forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@meta_register.register(torch.matmul)
def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.matmul meta info generator
There are several cases for torch.matmul:
1. Vector-vector multiplication => no temp memory, forward memory cost is 1 element (could be neglected), backward memory cost is the same
as two input vectors.
2. Matrix-vector multiplication => if the first input is matrix, no temp memory is needed, otherwise, there is a temp memory in the backward
phase for the transpose of the matrix. The forward memory cost is the size of output tensor, backward memory cost is the size of the two inputs; if
the first input is vector, the forward memory cost is the size of the output tensor, and during the backward phase, it will allocate a temp memory
the same size as the input matrix, and allocate memory for the gradient of two inputs.
3. Batched Matrix-vector multiplication => if the first input is the batched matrix, no temp memory, the forward memory cost is the size of
output tensor, backward memory cost is the size of the two inputs; if the second input is the batched matrix, the matmul will allocate memory for
the gradient of the batched matrix in the forward phase (as they create a new tensor without the former batches), so the forward memory cost is
the output tensor and the newly created matrix (take the same amount of memory of the input batched matrix). During the backward phase, it will
allocate a temp memory the same size as input batched matrix, and allocate a tensor for the gradient of the input vector. The gradient of the batched
matrix will be stored in the memory allocated during the forward phase.
3. Matrix-matrix multiplication => no temp memory, forward memory is the size of output tensor, backward memory is the size of the two inputs
4. Batched matrix-matrix multiplication => if the first input is the batched matrix, no temp memory, the forward memory cost is the size of two
inputs and backward memory cost is the size of the output tensor; if the second input is the batched matrix, during the forward phase it will allocate
memory for the output and gradient of the second input, and has a temp memory the same size as the output, during the backward phase, it
will allocate memory for the gradient of the first input and has a temp memory which is as big as output and the second input.
5. Batched matrix-batched matrix multiplication => if the two inputs have the same batch dimensions, no temp memory, the forward memory cost is the size
of output, backward memory cost is the size of the two inputs; it the two inputs have different batch dimensions, during the forward phase it will allocate
memory of the expanded inputs (so that the batch dimensions could match) and the output, and during the backward phase, it has a temp memory of the size of
two expanded inputs, and it will allocate memory for the gradient of the two inputs and discard the expanded inputs allocated during the forward phase.
Returns:
Tuple[TrainCycleItem, TrainCycleItem, bool]: compute cost, memory cost and forward inputs
"""
# Get input and output tensors
input_tensors = [args[0].data, args[1].data]
output_tensors = [args[-1].data]
# Check dimension
if all(len(tensor.shape) == 1 for tensor in input_tensors):
# Dot
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1:
# gemv case 1: matrix-vector multiplication
# &
# batched gemv case 1: batched matrix-vector multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0].reshape(-1), input_tensors[1]],
output_tensors) + \
flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2:
# gemv case 2: vector-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0)
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
# batched gemv case 2: vector-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
[output_tensors[0].reshape(-1)])
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0].reshape(-1), input_tensors[0]],
output_tensors
) + \
flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0)
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
# gemm & batched gemm case 1: batched matrix-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
[input_tensors[1]]
) + \
flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
# batched gemm case 2: matrix-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(
0, 1)
], [output_tensors[0].transpose(-2, -1)])
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
[input_tensors[0]]
) + \
flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
compute_size_in_bytes(input_tensors[1]),
temp=compute_size_in_bytes(output_tensors))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))
elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
# Batched matrix-batched matrix multiplication
# Fetch shape of the two inputs and see if the batch dimensions are the same
_is_batch_dims_same = True
if len(input_tensors[0].shape) == len(input_tensors[1].shape):
for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
if shape_0 != shape_1:
_is_batch_dims_same = False
break
else:
_is_batch_dims_same = False
# retireve dimensions
input_dim_00 = input_tensors[0].shape[-2]
input_dim_01 = input_tensors[0].shape[-1]
input_dim_10 = input_tensors[1].shape[-2]
input_dim_11 = input_tensors[1].shape[-1]
output_dim_0 = output_tensors[0].shape[-2]
output_dim_1 = output_tensors[0].shape[-1]
if _is_batch_dims_same:
# Case 1: batch dimensions are the same
# Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape(
-1, input_dim_10, input_dim_11)
], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
[input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
) + \
flop_mapping[torch.ops.aten.bmm.default](
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
[input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))
else:
# Case 2: batch dimensions are different
batch_dims = output_tensors[0].shape[:-2]
extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
input_dim_00,
input_dim_01,
device="meta")
extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
input_dim_10,
input_dim_11,
device="meta")
# Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
[extended_input_1]
) + \
flop_mapping[torch.ops.aten.bmm.default](
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
[extended_input_0]
)
fwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
compute_size_in_bytes([extended_input_0, extended_input_1]),
temp=compute_size_in_bytes([extended_input_0, extended_input_1]))
# compute cost
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# memory cost
total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = input_tensors
fwd_buffer = []
fwd_out = output_tensors
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
import operator
from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
__all__ = ["non_spmd_meta_info"]
@meta_register.register(torch.Size)
@meta_register.register(torch.Tensor.size)
@meta_register.register(torch.finfo)
@meta_register.register(operator.le)
def non_spmd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""Non-SPMD node meta information generator
Those nodes will not be handled by SPMD solver, so we just return all zero meta information for it
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
memory_cost = TrainCycleItem(fwd=MemoryCost(), bwd=MemoryCost(), total=MemoryCost())
fwd_in, fwd_buffer, fwd_out = [], [], []
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
from typing import Callable, Dict, List, Tuple, Union
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info']
@meta_register.register(torch.nn.BatchNorm1d)
@meta_register.register(torch.nn.BatchNorm2d)
@meta_register.register(torch.nn.BatchNorm3d)
def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""BatchNorm1d, BatchNorm2d, BatchNorm3d, meta info generator
The aten graph of BatchNorm2d is like
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%cudnn_batch_norm_default : [#users=4] = call_function[target=torch.ops.aten.cudnn_batch_norm.default](args = (%input_2, None, None, None, None, None, None, None), kwargs = {})
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%cudnn_batch_norm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {})
%detach_default_2 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {})
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {})
%cudnn_batch_norm_backward_default : [#users=3] = call_function[target=torch.ops.aten.cudnn_batch_norm_backward.default](args = (%detach_default, %zeros_like_default, None, None, None, %detach_default_1, %detach_default_2, None, %detach_default_3), kwargs = {})
%detach_default_4 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {})
%detach_default_5 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_4,), kwargs = {})
%detach_default_6 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {})
%detach_default_7 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_6,), kwargs = {})
%detach_default_8 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {})
%detach_default_9 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_8,), kwargs = {})
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
input_tensor = args[0].data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
mean_tensor = next(filter(lambda x: x.name == "running_mean", args)).data
var_tensor = next(filter(lambda x: x.name == "running_var", args)).data
num_batch = next(filter(lambda x: x.name == "num_batches_tracked", args)).data
# construct fwd args
# the fwd inputs are input, weight, bias, running_mean, running_var and some other args
# indicating the status of the module
# the fwd outputs are output, saved mean, saved inv std and num batches tracked
fwd_in_args = [input_tensor, weight_tensor, bias_tensor, mean_tensor, var_tensor, True, 0.1, 1e-5]
fwd_out_args = [output_tensor, mean_tensor, var_tensor, num_batch]
# construct bwd args
# the bwd inputs are upstream grad, input, weight, running_mean, running_var, saved mean,
# saved inv std and some other args indicating the status of the module
# the bwd outputs are input grad, weight grad and bias grad
bwd_in_args = [
output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch
]
bwd_out_args = [input_tensor, weight_tensor, bias_tensor]
# calculate cost
fwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm.default](fwd_in_args, fwd_out_args)
bwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm_backward.default](bwd_in_args, bwd_out_args)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# the fwd activation cost is output plus saved mean and saved inv std
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
[input_tensor, output_tensor, mean_tensor, var_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
# the bwd memory cost is quite tricky here, BatchNorm will remove saved mean
# and saved inv std during backward phase
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=compute_size_in_bytes([mean_tensor, var_tensor]),
buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@meta_register.register(torch.nn.LayerNorm)
def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""LayerNorm meta information
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
# construct needed tensors
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
running_mean = torch.rand(input_tensor.shape[0], 1, device='meta')
running_var = torch.rand(input_tensor.shape[0], 1, device='meta')
# construct args
fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor]
fwd_out_args = [output_tensor]
bwd_in_args = [input_tensor, output_tensor, [input_tensor.shape[0]]]
bwd_out_args = [weight_tensor, bias_tensor]
# compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.native_layer_norm.default](fwd_in_args, fwd_out_args)
bwd_compute_cost = flop_mapping[torch.ops.aten.native_layer_norm_backward.default](bwd_in_args, bwd_out_args)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
[input_tensor, output_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=compute_size_in_bytes([running_mean, running_var]))
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=compute_size_in_bytes([running_mean, running_var]),
buffer=compute_size_in_bytes([running_mean, running_var]))
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
__all__ = ["avgpool_meta_info", "maxpool_meta_info"]
@meta_register.register(torch.nn.AdaptiveAvgPool1d)
@meta_register.register(torch.nn.AdaptiveAvgPool2d)
@meta_register.register(torch.nn.AdaptiveAvgPool3d)
def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""Meta info for AdaptiveAvgPool
The aten graph of AdaptiveAvgPool is
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%_adaptive_avg_pool2d_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d.default](args = (%input_2, [None, None]), kwargs = {})
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%_adaptive_avg_pool2d_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%_adaptive_avg_pool2d_backward_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d_backward.default](args = (%zeros_like_default, %detach_default), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%_adaptive_avg_pool2d_backward_default,), kwargs = {})
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
input_tensor = args[0].data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
is_inplace = kwargs.get("inplace", False)
# construct forward args for flop mapping
fwd_in_args = [input_tensor]
fwd_out_args = [output_tensor]
# construct backward args for flop mapping
bwd_in_args = [output_tensor]
bwd_out_args = [input_tensor]
# calculate cost
# the fwd op with compute cost is _adaptive_avg_pool2d.default
# the bwd op with compute cost is _adaptive_avg_pool2d_backward.default
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
bwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d_backward.default](bwd_in_args, bwd_out_args)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(output_tensor))
bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(input_tensor))
# total cost
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation)
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
@meta_register.register(torch.nn.MaxPool1d)
@meta_register.register(torch.nn.MaxPool2d)
@meta_register.register(torch.nn.MaxPool3d)
def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""Meta info for MaxPool
The aten graph of MaxPool is
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%max_pool2d_with_indices_default : [#users=2] = call_function[target=torch.ops.aten.max_pool2d_with_indices.default](args = (%input_2, [None, None], [None, None]), kwargs = {})
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%max_pool2d_with_indices_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_default,), kwargs = {})
%max_pool2d_with_indices_backward_default : [#users=1] = call_function[target=torch.ops.aten.max_pool2d_with_indices_backward.default](args = (%zeros_like_default, %detach_default, [None, None], [None, None], [None, None], [None, None], None, %detach_default_1), kwargs = {})
%detach_default_2 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_backward_default,), kwargs = {})
%detach_default_3 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_2,), kwargs = {})
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
# construct forward args for flop mapping
fwd_in_args = [input_tensor]
fwd_out_args = [output_tensor]
# construct backward args for flop mapping
bwd_in_args = [output_tensor]
bwd_out_args = [input_tensor]
# construct index matrix
index_matrix = torch.zeros_like(output_tensor, device="meta", dtype=torch.int64)
# calculate cost
# the fwd op with compute cost is max_pool2d_with_indices.default
# the bwd op with compute cost is max_pool2d_with_indices_backward.default
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices.default](fwd_in_args, fwd_out_args)
bwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices_backward.default](bwd_in_args, bwd_out_args)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# NOTE: the index matrix will be discarded in backward phase
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix]))
# temp memory for backward is the index matrix to be discarded
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
temp=compute_size_in_bytes(index_matrix))
# total cost
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
from typing import Callable, List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
__all__ = ["tensor_related_metainfo"]
def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: float = 0) -> Callable:
"""torch.Tensor related metainfo generator template
Args:
bwd_mem_out_factor (float, optional): backward activation memory cost factor. Defaults to 1.
bwd_mem_tmp_factor (float, optional): backward temp memory cost factor. Defaults to 0.
Returns:
Callable: torch.Tensor related metainfo generator
"""
def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.Tensor related metainfo generator
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
outputs = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
# compute costs are all zero
compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
# memory costs
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
parameter=0,
temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
buffer=0)
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
if isinstance(outputs, tuple) or isinstance(outputs, list) or isinstance(outputs, dict):
# tuple of tensors
fwd_out = [torch.zeros_like(tensor) for tensor in outputs]
else:
# enaged_tensors is a single tensor
fwd_out = [torch.zeros_like(outputs)]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
return meta_func
# register torch.Tensor related metainfo
# (0, 0)
meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze,
torch.arange])(tensor_related_metainfo(0, 0))
# (1, 0)
meta_register.register([
torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute,
torch.Tensor.split, torch.split, torch.Tensor.view
])(tensor_related_metainfo(1, 0))
# (1, 1)
meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1))
from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
__all__ = ["where_meta_info"]
@meta_register.register(torch.where)
def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.where meta information generator
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
condition_tensor, x_tensor, y_tensor, output_tensor = [arg.data for arg in args]
# compute cost
fwd_compute_cost = 0
# if we need to broadcast the condition tensor, during backward we need to do a reduce_sum
bwd_compute_cost = 0
if x_tensor.shape != output_tensor.shape:
bwd_compute_cost += flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], [x_tensor])
if y_tensor.shape != output_tensor.shape:
bwd_compute_cost += flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], [y_tensor])
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# memory cost
# during the forward phase, torch.where will allocate memory for output tensor and condition tensor
# during the backward phase, torch.where will allocate temp memory which is 3 times as output tensor, then generate
# gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
# NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward
fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor]))
bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
parameter=0,
temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) -
activation_size([x_tensor, y_tensor]),
buffer=0)
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [condition_tensor]
fwd_buffer = []
fwd_out = [output_tensor]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
__all__ = ['Registry']
class Registry:
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
def wrapper(func):
if isinstance(source, (list, tuple)):
# support register a list of items for this func
for element in source:
self.store[element] = func
else:
self.store[source] = func
return func
return wrapper
def get(self, source):
assert source in self.store, f'{source} not found in the {self.name} registry'
target = self.store[source]
return target
def has(self, source):
return source in self.store
meta_register = Registry('meta')
from typing import Callable, List
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register
__all__ = ['ShardMetaInfo']
class ShardMetaInfo:
"""ShardMetaInfo class
This class is used to store meta info based on sharding strategy and the given
target function.
"""
def __init__(self, strategy: ShardingStrategy = None, target: Callable = None) -> None:
# compute cost of forward and backward computation
self.compute_cost: TrainCycleItem
# compute memory cost of forward and backward phase
self.memory_cost: TrainCycleItem
# list of input tensors
self.fwd_in: List[torch.Tensor]
# list of buffer tensors
self.fwd_buffer: List[torch.Tensor]
# list of output tensors
self.fwd_out: List[torch.Tensor]
# sharding strategy
self._strategy = strategy
# target function
self._target = target
# compute shard_metainfo if possible
if self._strategy is not None and self._target is not None:
self.compute_shard_metainfo()
@property
def strategy(self) -> ShardingStrategy:
return self._strategy
@property
def target(self) -> Callable:
return self._target
@strategy.setter
def strategy(self, strategy: ShardingStrategy) -> None:
self._strategy = strategy
if self._strategy is not None and self._target is not None:
self.compute_shard_metainfo()
@target.setter
def target(self, target: Callable) -> None:
self._target = target
if self._strategy is not None and self._target is not None:
self.compute_shard_metainfo()
def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec):
"""
Compute sharded opdata based on the given data and sharding spec.
"""
if isinstance(sharding_spec, ShardingSpec):
op_data = OperationData(name=operation_data.name,
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape)
elif isinstance(sharding_spec, (list, tuple)):
data = operation_data.data
assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
assert len(data) == len(sharding_spec), f"Length of data and sharding spec should be the same."
sharded_data = []
for d, s in zip(data, sharding_spec):
sharded_data.append(torch.zeros(s.get_sharded_shape_per_device(), device="meta"))
op_data = OperationData(name=operation_data.name, data=sharded_data, type=operation_data.type)
else:
raise ValueError(f"Sharding spec should be ShardingSpec or list, but got {type(sharding_spec)}.")
return op_data
def compute_shard_metainfo(self):
"""
Compute meta info based on sharding strategy and the given target function.
"""
assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
f"Meta info for {self._target} is not registered."
if meta_register.has(self._target.__class__):
# module
meta_func = meta_register.get(self._target.__class__)
# check whether the target in the list that we don't need to save activation
save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
else:
# function
meta_func = meta_register.get(self._target)
# check whether the target in the list that we don't need to save activation
save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
# construct args for meta_func
args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]
# construct kwargs
if self.target in INPLACE_MODULE:
kwargs = {'inplace': self.target.inplace}
elif self.target in INPLACE_OPS:
kwargs = {'inplace': True}
else:
kwargs = {'inplace': False}
# compute metainfo with meta_func
self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
# process corner case for NO_SAVE_ACTIVATION
if not save_fwd_in:
self.fwd_in = []
from typing import Dict, Tuple
from enum import Enum
import torch
from torch.optim import Optimizer
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device
from .base_offload_module import BaseOffloadModule
from .region_manager import RegionManager
from .region import Region
class OptimState(Enum):
SCALED = 0
UNSCALED = 1
class AMPOptimizer(ColossalaiOptimizer):
"""
A wrapper for Optimizer.
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
Args:
optimizer (Optimizer): An Optimizer instance.
module (BaseOffloadModule): A ``BaseOffloadModule`` instance.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
"""
def __init__(self,
optimizer: Optimizer,
module: BaseOffloadModule,
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
clipping_norm: float = 0.0,
norm_type: float = 2.0):
super().__init__(optimizer)
self.module = module
self.optim_state = OptimState.UNSCALED
self.clipping_flag = clipping_norm > 0.0
self.max_norm = clipping_norm
self.region_manager: RegionManager = self.module.region_manager
self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict()
self.param_to_region: Dict[torch.nn.Parameter, Region] = dict()
self.fp32_to_fp16_params: Dict[torch.Tensor, torch.nn.Parameter] = dict()
if self.clipping_flag:
assert norm_type == 2.0, "AMPOptimizer only supports L2 norm now"
self.__init__optimizer()
# Grad scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
self._logger = get_dist_logger()
def _set_grad_ptr(self):
for group in self.param_groups:
for fake_param in group['params']:
region = self.param_to_region[fake_param]
begin, end = self.param_to_range[fake_param]
fake_param.data = region.cpu_grad[begin:end]
fake_param.grad = fake_param.data
fake_param.data = region.fp32_data[begin:end]
def _update_fp16_params(self):
none_tensor = torch.empty([0])
for group in self.param_groups:
for fake_param in group['params']:
assert fake_param.grad is None
fake_param.data = none_tensor
self.param_to_region[fake_param].cpu_grad = None
def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(self.module.overflow_counter.item())
return self._found_overflow.item() > 0
def _get_combined_scale(self):
loss_scale = 1
if self.optim_state == OptimState.SCALED:
loss_scale = self.loss_scale
self.optim_state = OptimState.UNSCALED
combined_scale = loss_scale
if combined_scale == 1:
return -1
else:
return combined_scale
@property
def loss_scale(self):
return self.grad_scaler.scale.item()
def zero_grad(self, *args, **kwargs):
self.module.overflow_counter = torch.cuda.IntTensor([0])
return self.optim.zero_grad(set_to_none=True)
def step(self, *args, **kwargs):
# Copy gradients from model params to main params.
self._set_grad_ptr()
found_inf = self._check_overflow()
if found_inf:
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
self._logger.info(f'Found overflow. Skip step')
self.zero_grad() # reset all gradients
self._update_fp16_params()
return
# get combined scale. combined scale = loss scale * clipping norm
# so that gradient = gradient / combined scale
combined_scale = self._get_combined_scale()
self.grad_scaler.update(found_inf)
ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
self.zero_grad()
self._update_fp16_params()
return ret
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
raise NotImplementedError
def backward(self, loss: torch.Tensor):
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
self.module.backward(loss)
def __init__optimizer(self):
for group in self.optim.param_groups:
fake_params_list = list()
for param in group['params']:
region = self.region_manager.get_region(param)
fake_param = torch.nn.Parameter(torch.empty([0]))
self.param_to_range[fake_param] = region.param_to_range[param]
self.param_to_region[fake_param] = region
fake_params_list.append(fake_param)
# Reset existing state dict key to the new main param.
if param in self.optim.state:
self.optim.state[fake_param] = self.optim.state.pop(param)
group['params'] = fake_params_list
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
self.optim.load_state_dict(self.optim.state_dict())
\ No newline at end of file
from functools import partial
from typing import Optional, Set
import torch
import torch.nn as nn
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.zero.legacy.gemini.tensor_utils import free_storage
from .region_manager import RegionManager
from .util import GlobalRuntimeInfo
class BaseOffloadModule:
"""
BaseOffloadModule: A model wrapper for parameter offloading.
Args:
model (nn.Module): model to apply offloading.
region_manager (RegionManager): a ``RegionManager`` instance.
is_sync (bool): synchronous mode or not.
"""
def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):
self.model = model
self.region_manager = region_manager
self.grad_hook_list = []
self.overflow_counter = torch.cuda.IntTensor([0])
self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream
self._cast_buffers()
def register_grad_hook(self):
for p in self.model.parameters():
if p.requires_grad:
self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
def remove_grad_hook(self):
for hook in self.grad_hook_list:
hook.remove()
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def _pre_forward(self):
self.register_grad_hook()
for region in self.region_manager.region_list:
region.cpu_grad = None
def forward(self, *args, **kwargs):
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.model.zero_grad(set_to_none=True)
self._pre_forward()
outputs = self.model(*args, **kwargs)
return outputs
def backward(self, loss):
loss.backward()
self._post_backward()
def _post_backward(self):
torch.cuda.synchronize()
self.remove_grad_hook()
for p in self.model.parameters():
p.grad = None
GlobalRuntimeInfo().fwd_prefetch_event_map.clear()
GlobalRuntimeInfo().bwd_prefetch_event_map.clear()
def grad_handle(self, p, grad):
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
region = self.region_manager.get_region(p)
region.copy_grad_to_region_slice(p, grad)
if region.can_release:
self.overflow_counter += region.has_inf_or_nan
master_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.grad_offload_stream):
GlobalRuntimeInfo().d2h_stream.wait_stream(master_stream)
region.move_grad_to_cpu()
return empty_grad
def _cast_buffers(self):
for buffer in self.model.buffers():
buffer.data = buffer.cuda()
def parameters(self, recurse: bool = True):
return self.model.parameters(recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True):
return self.model.named_parameters(prefix, recurse)
def named_buffers(self, prefix: str = '', recurse: bool = True):
return self.model.named_buffers(prefix, recurse)
def named_children(self):
return self.model.named_children()
def named_modules(self,
memo: Optional[Set[torch.nn.Module]] = None,
prefix: str = '',
remove_duplicate: bool = True):
return self.model.named_modules(memo, prefix, remove_duplicate)
from typing import Dict
import torch
import torch.fx
from torch.fx import GraphModule
from torch.utils._pytree import tree_map
from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from .base_offload_module import BaseOffloadModule
from .region_manager import RegionManager
from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_pass
from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem
def memory_optimize(model: torch.nn.Module,
inps: Dict[str, torch.Tensor],
memory_budget: float = -1.0,
solver_name: str = 'asyn'):
model = model.cpu().half()
tracer = ColoTracer()
assert is_compatible_with_meta()
wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x
meta_args = tree_map(wrap_fn, inps)
graph = tracer.trace(model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
interp = MetaInfoProp(gm)
interp.propagate(*meta_args.values())
region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget)
region_manager._build_regions()
GlobalRuntimeInfo().region_list = region_manager.region_list
act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024**2
max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024**2
total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024**2
print(
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}"
)
if solver_name == 'syn':
gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
elif solver_name == 'asyn':
gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list)
else:
raise TypeError(f"Unknown solver name {solver_name}!")
gm.recompile()
optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
return optimized_model
from typing import Dict, List, Tuple
import torch
from torch.fx import Node
from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
class Region:
"""
Region: A container owning a piece of contiguous nodes in the DNN computing graph.
Args:
r_id (int): the index of the region in the computing graph.
"""
def __init__(self, r_id: int = 0) -> None:
self.r_id: int = r_id
self.fp16_params: List[torch.nn.Parameter] = []
self.param_size: int = 0
self.shared_rid: int = self.r_id
self.param_num: int = 0
self.grad_num: int = 0
self.fp16_data = None
self.fp32_data = None
self.cpu_grad = None
self.temp_fp32_data = None
self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict()
self.need_offload: bool = False
self.is_syn: bool = False
self.nodes: List[Node] = []
self.fwd_prefetch_region = None
self.bwd_prefetch_region = None
self.in_mem_pool_flag: bool = False
@property
def can_release(self) -> bool:
"""
Check if the region can be released.
"""
return self.grad_num == self.param_num
@property
def has_inf_or_nan(self) -> bool:
"""
Check if the grad of the region has inf or nan values on CUDA.
"""
return torch.isinf(self.fp16_data).any() | torch.isnan(self.fp16_data).any()
def init_param_data(self, pre_alloc_tensor: torch.Tensor = None):
"""
Map the parameters in the region to a contiguous memory space.
"""
self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda')
offset = 0
for param in self.fp16_params:
param.data = param.data.cuda()
p_num = param.data.numel()
self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape)
self.param_to_range[param] = (offset, offset + p_num)
offset += p_num
self.fp32_data = self.fp16_data.float().cpu().pin_memory()
free_storage(self.fp16_data)
if self.in_mem_pool_flag and pre_alloc_tensor is not None:
self.fp16_data = pre_alloc_tensor
def move_param_to_cuda(self):
"""
Move parameters from CPU to GPU.
It first moves float32 parameters to GPU and
then transforms float32 parameters to half-precision on the GPU.
The reason is that the performance of precision conversion on the CPU
is much slower than the data transfer overhead.
"""
self.temp_fp32_data.copy_(self.fp32_data, non_blocking=True)
self.temp_fp32_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
alloc_storage(self.fp16_data)
self.fp16_data[:self.param_num].copy_(self.temp_fp32_data)
self.fp16_data.record_stream(torch.cuda.current_stream())
self.__update_params_ptr()
def move_grad_to_cpu(self):
"""
Move gradients from GPU to CPU.
"""
self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True)
self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True)
self.fp16_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
self.free_cuda_data()
self.grad_num = 0
def free_cuda_data(self):
free_storage(self.fp16_data)
# torch.cuda.empty_cache()
def copy_grad_to_region_slice(self, param: torch.nn.Parameter, data_slice: torch.Tensor) -> None:
"""
Copy data slice to the memory space indexed by the input tensor in the region.
Args:
param (torch.nn.Parameter): the param used to retrieve meta information
data_slice (torch.Tensor): the tensor to be copied to the region
"""
begin, end = self.param_to_range[param]
self.fp16_data[begin:end].copy_(data_slice.data.flatten())
param.data = self.fp16_data[begin:end].view(param.data.shape)
self.grad_num += data_slice.numel()
def split(self, cut_node_idx: int, cut_param_idx: int):
"""
Split the region into two and return the latter.
"""
new_reg = Region(r_id=self.r_id + 1)
new_reg.nodes = self.nodes[cut_node_idx:]
new_reg.fp16_params = self.fp16_params[cut_param_idx:]
for p in new_reg.fp16_params:
new_reg.param_size += p.data.numel() * p.data.element_size()
new_reg.param_num += p.data.numel()
self.nodes = self.nodes[:cut_node_idx]
self.fp16_params = self.fp16_params[:cut_param_idx]
self.param_size -= new_reg.param_size
self.param_num -= new_reg.param_num
return new_reg
def __update_params_ptr(self) -> None:
for param in self.fp16_params:
begin, end = self.param_to_range[param]
param.data = self.fp16_data[begin:end].view(param.data.shape)
from typing import List, Any, Dict, Tuple
import torch
from torch.fx import Graph, Node
from .solver import SolverFactory
from .training_simulator import TrainingSimulator
from .region import Region
from .util import NodeInfo
class RegionManager:
"""
RegionManager is used to construct and manage the offload plan for the model execution.
Args:
graph (Graph): a Graph object used for analysis and strategy generation.
solver_name (str): a solver name which specifies the preferences for plan searching.
memory_budget (float): the given memory budget.
cnode (List[str], optional): Common node List, should be the subset of input.
"""
def __init__(self,
graph: Graph,
solver_name: str = 'asyn',
memory_budget: float = -1.0,
cnode: List[str] = None):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.cnode = cnode
self.only_param_ops = []
self.param_region_map: Dict[torch.nn.Parameter, Region] = dict()
self.shared_region_pairs: List[Tuple[Region, Region]] = list()
self.region_list: List[Region] = list()
self.rid_in_pool: List[int] = list()
self.mem_block_size: int = 0
self.memory_budget = memory_budget
self.solver_name = solver_name
self.require_pool: bool = solver_name == 'asyn'
self.reg_to_block: Dict[int, int] = dict()
def _build_regions(self):
"""
1. Pre-processing, mainly contains linearized computing graph and
merge smaller regions into larger ones.
2. Construct a solver to search for an efficient offload strategy.
3. Post-processing, mainly contains early region placement if using asynchronous mode,
and initialize region data.
"""
self._pre_process()
solver_cls = SolverFactory.create(self.solver_name)
solver = solver_cls(self.region_list, self.memory_budget)
solver._call_solver()
self._post_process(solver.best_ts)
def _pre_process(self):
init_region_list = self._linearize_graph()
if len(self.shared_region_pairs) > 1:
raise NotImplementedError(
'The current version only considers at most one pair of parameter sharing.')
elif len(self.shared_region_pairs) == 1:
shared_regs = self.shared_region_pairs[0]
assert shared_regs[0].shared_rid == shared_regs[1].r_id \
and shared_regs[1].shared_rid == shared_regs[0].r_id
fst_id = shared_regs[0].r_id
lst_id = shared_regs[1].r_id
regs_left_out = init_region_list[:fst_id + 1]
regs_right_out = init_region_list[lst_id:]
hold_regs = init_region_list[fst_id + 1:lst_id]
else:
regs_left_out = []
regs_right_out = []
hold_regs = init_region_list
self.mem_block_size = self._search_block_size(hold_regs)
hold_regs = self._merge_small_regions(hold_regs)
if self.require_pool:
for reg in hold_regs:
reg.in_mem_pool_flag = True
self.rid_in_pool.append(reg.r_id)
self.region_list.extend(regs_left_out)
self.region_list.extend(hold_regs)
for reg in regs_right_out:
reg.r_id = self.region_list[-1].r_id + 1
self.region_list[reg.shared_rid].shared_rid = reg.r_id
self.region_list.append(reg)
self._process_shared_region()
self.max_param_num = max([reg.param_num for reg in self.region_list])
self.memory_budget -= self.max_param_num * torch.tensor([], dtype=torch.float32).element_size()
def _post_process(self, ts: TrainingSimulator = None):
if self.require_pool:
self._early_region_placement(ts)
self._init_region_data()
def _early_region_placement(self, ts: TrainingSimulator):
"""
Implemented the early region placement strategy to avoid GPU memory fragmentation.
It maps all region data into a contiguous memory space and
reuses the same memory space for regions that do not coexist.
Args:
ts (TrainingSimulator): the best training simulator, which records region execution flow.
Raises:
NotImplementedError: due to the naive implementation,
it may not find a suitable region placement strategy for the given execution flow.
"""
reg_flow = torch.cat(
[ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
mem_block_num = torch.max(
torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
coexist_matrix = torch.logical_or(
ts.fwd_reg_flow, ts.bwd_reg_flow)
block_to_regs = {}
for block_idx in range(mem_block_num):
block_to_regs[block_idx] = []
for reg in self.region_list:
if reg.r_id in self.rid_in_pool:
cur_reg_appears = coexist_matrix[:, reg.r_id]
cur_reg_coexists = torch.sum(
coexist_matrix[cur_reg_appears], dim=0).bool()
for block_idx in range(mem_block_num):
if not any(cur_reg_coexists[block_to_regs[block_idx]]):
block_to_regs[block_idx].append(reg.r_id)
self.reg_to_block[reg.r_id] = block_idx
break
if reg.r_id not in self.reg_to_block:
raise NotImplementedError(
f'can not find a block from the memory pool to store parameters of the region')
self.memory_pool = torch.chunk(torch.zeros(int(
mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num))
def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
"""
Merge smaller regions into larger ones for better bandwidth utilization and easier management.
It is inspired by Gemini.
Args:
orig_reg_list (List[Region]): original region list.
Returns:
List[Region]: region list after merging.
"""
r_id = orig_reg_list[0].r_id
region = Region(r_id=r_id)
region_list = [region]
for orig_reg in orig_reg_list:
if region_list[-1].param_size + orig_reg.param_size > self.mem_block_size:
r_id += 1
region = Region(r_id=r_id)
region_list.append(region)
region.param_size += orig_reg.param_size
region.param_num += orig_reg.param_num
region.nodes.extend(orig_reg.nodes)
region.fp16_params.extend(orig_reg.fp16_params)
self.__update_param_region_map(orig_reg.fp16_params, region)
return region_list
def _search_block_size(self,
region_list: List[Region],
search_interval_byte: int = 1024,
search_range_byte: int = 128 * 1024 ** 2) -> int:
"""
Search for a suitable memory block size.
Args:
region_list (List[Region]): region list.
search_interval_byte (int): searching interval in byte.
search_range_byte (int): searching range in byte.
Returns:
int: the best memory block size.
"""
def _get_wasted_mem(size_list: List[int], blk_size: int):
"""
Get wasted byte for a certain block size.
"""
acc_wasted = 0
left = 0
for s in size_list:
if left + s > blk_size:
acc_wasted += blk_size - left
left = s
left += s
acc_wasted += blk_size - left
return acc_wasted
param_size_list = [
region.param_size for region in region_list if region.r_id == region.shared_rid]
start_size = max(param_size_list)
min_mem_waste = float('+inf')
best_block_size = start_size
for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
temp_waste = 0
temp_waste += _get_wasted_mem(param_size_list, block_size)
if temp_waste < min_mem_waste:
min_mem_waste = temp_waste
best_block_size = block_size
return best_block_size
def _init_region_data(self):
"""
Initialize region data, which maps the parameters in the region to a contiguous memory space.
"""
self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32)
for region in self.region_list:
pre_alloc_tensor = None
if self.require_pool and region.r_id in self.rid_in_pool:
block_idx = self.reg_to_block[region.r_id]
pre_alloc_tensor = self.memory_pool[block_idx]
if region.r_id <= region.shared_rid:
region.init_param_data(pre_alloc_tensor)
else:
shared_region = self.region_list[region.shared_rid]
region.fp16_data = shared_region.fp16_data
region.fp32_data = shared_region.fp32_data
region.param_to_range = shared_region.param_to_range
region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach(
)
torch.cuda.empty_cache()
def _process_shared_region(self):
"""
Special processing for the shared region, which uses GPT2 and Bert case as a priori knowledge.
"""
if len(self.shared_region_pairs):
assert len(self.shared_region_pairs) <= 1
former_reg, latter_reg = self.shared_region_pairs[0]
assert latter_reg.param_num >= former_reg.param_num
embedding_node = former_reg.nodes[-1]
assert embedding_node.op == 'call_module' and isinstance(
self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding)
if latter_reg.param_num > former_reg.param_num:
for idx, n in enumerate(latter_reg.nodes):
if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target),
torch.nn.Linear)) or \
(n.op == 'call_function' and n.target is torch.nn.functional.linear):
cut_node_idx = idx + 1
break
assert len(latter_reg.fp16_params) == 2
new_reg = latter_reg.split(cut_node_idx, 1)
for p in new_reg.fp16_params:
self.param_region_map[p] = new_reg
self.region_list.insert(new_reg.r_id, new_reg)
for reg in self.region_list[new_reg.r_id + 1:]:
reg.r_id += 1
latter_reg.shared_rid = former_reg.r_id
former_reg.shared_rid = latter_reg.r_id
def _linearize_graph(self) -> List[Region]:
"""Linearizing the graph
Args:
graph (Graph): The computing graph to be optimized.
Returns:
List[Region]: each region contains the actual 'node' in linearized manner.
Remarks:
Do merge the inplace ops and shape-consistency ops into the previous node.
"""
# List of target name that could be seen as common node
common_ops = ["getattr", "getitem", "size"]
def _is_cop(target: Any) -> bool:
"""Check if an op could be seen as common node
Args:
target (Any): node target
Returns:
bool
"""
if isinstance(target, str):
return target in common_ops
else:
return target.__name__ in common_ops
def _is_act(data: Any) -> bool:
"""Check if an op could be seen as parameter computation start
Args:
data (Any): meta_data
Returns:
bool
"""
label = False
if isinstance(data, torch.Tensor):
return True
elif isinstance(data, (tuple, list)):
for d in data:
label = label or _is_act(d)
return label
def _maybe_param_comp_start() -> bool:
"""Check if an op could be seen as parameter computation start
Args:
n (Node): node
Returns:
bool
"""
label = False
if n.op == "get_attr":
label = True
elif n.op == "call_module":
target = n.target
submod = self.root_module.get_submodule(target)
if (
len(list(submod.named_parameters(recurse=False))) != 0
or len(list(submod.named_buffers(recurse=False))) != 0
):
label = True
return label and not sum([v for _, v in param_op_deps.items()])
def _is_param_comp_end() -> bool:
"""Check if an op could be seen as parameter computation end
Args:
n (Node): node
Returns:
bool
"""
def _is_inplace(n: Node):
"""Get the inplace argument from ``torch.fx.Node``
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(
n.target), "inplace", False)
return inplace
label = False
if n.op == "call_module":
target = n.target
submod = self.root_module.get_submodule(target)
if (
len(list(submod.named_parameters(recurse=False))) != 0
or len(list(submod.named_buffers(recurse=False))) != 0
):
label = True
elif n.op == "call_function":
label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any(
map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes))
return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users))
def _exception_node_handling():
# TODO meta info prop bug
if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2:
n.meta['fwd_out'] = []
# make sure that item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
else:
self.cnode = []
node_id = 0
region_id = 0
param_op_deps = {}
deps = {}
region_list = []
region = Region(r_id=region_id)
act_n = None
for n in self.graph.nodes:
if n.op != "placeholder" and n.op != "output":
for n_par in n.all_input_nodes:
if n_par.op != "placeholder" and n_par.name not in self.cnode:
deps[n_par] -= 1
if n_par.op != "placeholder" and n_par.name in self.only_param_ops:
param_op_deps[n_par] -= 1
if act_n in region.nodes and _maybe_param_comp_start():
ns = []
border_n_idx = region.nodes.index(act_n)
if border_n_idx < len(region.nodes):
ns = region.nodes[border_n_idx + 1:]
region.nodes = region.nodes[:border_n_idx + 1]
region_list.append(region)
region_id += 1
region = Region(r_id=region_id)
region.nodes = ns
_exception_node_handling()
region.nodes.append(n)
self._set_node_and_region_info(node_id, n, region)
node_id += 1
# if the node could free all dependencies in graph
# we could begin a new region
if _is_param_comp_end():
region_list.append(region)
region_id += 1
region = Region(r_id=region_id)
# propagate common node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
]) or _is_cop(n.target):
self.cnode.append(n.name)
else:
deps[n] = len(
[user for user in n.users if user.op != "output"])
# propagate param node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops
]) or n.op == "get_attr":
self.only_param_ops.append(n.name)
param_op_deps[n] = len(
[user for user in n.users if user.op != "output"])
# record last activation node
if _is_act(n._meta_data):
act_n = n
if len(region.nodes):
region_list.append(region)
return region_list
def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
cur_n.node_info = NodeInfo(node_id)
if cur_n.op == 'call_module':
target = cur_n.target
submod = self.root_module.get_submodule(target)
for p in list(submod.parameters(recurse=False)):
if p in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[p].r_id
self.param_region_map[p].shared_rid = cur_reg.r_id
self.shared_region_pairs.append(
(self.param_region_map[p], cur_reg))
else:
self.param_region_map[p] = cur_reg
cur_reg.fp16_params.append(p)
cur_reg.param_num += p.data.numel()
cur_reg.param_size += p.data.numel() * p.data.element_size()
elif cur_n.op == "get_attr":
attr_itr = self.root_module
atoms = cur_n.target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.nn.Parameter):
if attr_itr in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
self.shared_region_pairs.append(
(self.param_region_map[attr_itr], cur_reg))
else:
self.param_region_map[attr_itr] = cur_reg
cur_reg.fp16_params.append(attr_itr)
cur_reg.param_num += attr_itr.data.numel()
cur_reg.param_size += attr_itr.data.numel() * attr_itr.data.element_size()
def get_region(self, param: torch.nn.Parameter) -> Region:
"""
Return the region owning the parameter.
Args:
param (torch.nn.Parameter): a torch parameter object
"""
return self.param_region_map[param]
def __update_param_region_map(self, params: List[torch.nn.Parameter], region: Region):
for p in params:
self.param_region_map[p] = region
from typing import List
import torch
from torch.fx.node import Node
from .region import Region
from .util import GlobalRuntimeInfo, requires_upload_p_in_fwd
class SynPreFwdPostBwdOP(torch.autograd.Function):
"""
A customized prefetch and offload operation.
Args:
input_: input tensor.
fwd_info: information dict, which contains region indices
that need to be uploaded or freed during forward pass.
bwd_info: information dict, which contains region indices
that need to be uploaded during backward pass.
"""
@staticmethod
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
d2h_rid = fwd_info.get('d2h_rid', None)
if d2h_rid is not None:
free_region = GlobalRuntimeInfo().region_list[d2h_rid]
assert isinstance(free_region, Region)
free_region.free_cuda_data()
h2d_rid = fwd_info.get('h2d_rid', None)
if h2d_rid is not None:
h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(h2d_region, Region)
h2d_region.move_param_to_cuda()
return input_
@staticmethod
def backward(ctx, grad_output):
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
pref_region.move_param_to_cuda()
return grad_output, None, None
class AsynPreFwdPostBwdOP(torch.autograd.Function):
"""
A customized prefetch and offload operation.
Args:
input_: input tensor.
fwd_info: information dict, which contains region indices
that need to be prefetched, waited, or freed during forward pass.
bwd_info: information dict, which contains region indices
that need to be prefetched or waited during backward pass.
"""
@staticmethod
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
sync_rid = fwd_info.get('sync_rid', None)
if sync_rid is not None:
prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
if prefetch_event:
prefetch_event.wait()
h2d_rid = fwd_info.get('h2d_rid', None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
master_stream = torch.cuda.current_stream()
with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
pref_region.move_param_to_cuda()
prefetch_event = torch.cuda.Event()
prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
GlobalRuntimeInfo().fwd_prefetch_event_map[h2d_rid] = prefetch_event
return input_
@staticmethod
def backward(ctx, grad_output):
sync_rid = ctx.bwd_info.get('sync_rid', None)
if sync_rid is not None:
wait_region = GlobalRuntimeInfo().region_list[sync_rid]
assert isinstance(wait_region, Region)
prefetch_event = GlobalRuntimeInfo().bwd_prefetch_event_map.get(sync_rid, None)
if prefetch_event:
prefetch_event.wait()
else:
wait_region.move_param_to_cuda()
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
master_stream = torch.cuda.current_stream()
with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
pref_region.move_param_to_cuda()
prefetch_event = torch.cuda.Event()
prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
GlobalRuntimeInfo().bwd_prefetch_event_map[h2d_rid] = prefetch_event
return grad_output, None, None
def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
'''
Convert Upload and Offload operation into runtime action.
Argument:
tensor(torch.Tensor): input tensor.
fwd_info(dict): information dict, which contains region indices
that need to be uploaded, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be uploaded during backward pass.
'''
with torch._C.DisableTorchFunction():
ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
'''
Convert Prefetch and Offload operation into runtime action.
Argument:
tensor(torch.Tensor): input tensor.
fwd_info(dict): information dict, which contains region indices
that need to be prefetched, waited, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be prefetched or waited during backward pass.
'''
with torch._C.DisableTorchFunction():
ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
def replace_node_users(orig_node: Node, inserted_node: Node, rep_user_nodes: List[Node] = None):
user_list = list(orig_node.users.keys())
if rep_user_nodes is not None:
user_list = rep_user_nodes
for user in user_list:
if user == inserted_node:
continue
new_args = list(user.args)
new_kwargs = dict(user.kwargs)
# the origin node may be a positional argument or key word argument of user node
if orig_node in new_args:
# substitute the origin node with offload_apply_node
new_args[new_args.index(orig_node)] = inserted_node
user.args = tuple(new_args)
elif str(orig_node) in new_kwargs:
# substitute the origin node with offload_apply_node
new_kwargs[str(orig_node)] = inserted_node
user.kwargs = new_kwargs
def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
"""
This pass is used to add the synchronous upload and offload spec apply node to the origin graph.
"""
mod_graph = gm.graph
last_inp_node = tuple(mod_graph.nodes)[0]
for r_idx, region in enumerate(region_list):
# forward upload
fwd_info = {}
if requires_upload_p_in_fwd(region_list[region.shared_rid]):
fwd_info['h2d_rid'] = region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
fwd_info['d2h_rid'] = r_idx - 1
bwd_info = {}
# backward upload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function',
convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info))
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
return gm
def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
"""
This pass is used to add the asynchronous prefetch and offload spec apply node to the origin graph.
"""
mod_graph = gm.graph
# upload parameters of the first region
last_inp_node = tuple(mod_graph.nodes)[0]
first_region_with_p = [region for region in region_list if region.param_size][0]
fwd_info = {"h2d_rid": first_region_with_p.r_id}
with mod_graph.inserting_after(last_inp_node):
upload_apply_node = mod_graph.create_node('call_function',
convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, {}))
replace_node_users(last_inp_node, upload_apply_node)
last_inp_node = upload_apply_node
for r_idx, region in enumerate(region_list):
# forward prefetch
fwd_info = {}
if region.param_size:
fwd_info['sync_rid'] = region.r_id
fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
fwd_info['h2d_rid'] = fwd_prefetch_region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
fwd_info['d2h_rid'] = r_idx - 1
bwd_info = {}
# backward prefetch
if r_idx > 0 and region_list[r_idx - 1].need_offload:
bwd_info['sync_rid'] = r_idx - 1
if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function',
convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info))
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
if region.bwd_prefetch_region:
bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function',
convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, {}, bwd_info))
replace_node_users(last_inp_node, new_node)
# gm.graph.print_tabular()
return gm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment