Commit 9e768b59 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents 7bc5a8e3 8aed02b9
...@@ -3,7 +3,7 @@ from typing import List, Tuple ...@@ -3,7 +3,7 @@ from typing import List, Tuple
import torch import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register from ..registry import meta_register
......
from typing import Callable, Dict, List, Tuple, Union from typing import List, Tuple
import torch import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register from ..registry import meta_register
__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info'] __all__ = ["batchnormnd_meta_info", "layernorm_meta_info"]
@meta_register.register(torch.nn.BatchNorm1d) @meta_register.register(torch.nn.BatchNorm1d)
...@@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt ...@@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# saved inv std and some other args indicating the status of the module # saved inv std and some other args indicating the status of the module
# the bwd outputs are input grad, weight grad and bias grad # the bwd outputs are input grad, weight grad and bias grad
bwd_in_args = [ bwd_in_args = [
output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch output_tensor,
output_tensor,
weight_tensor,
mean_tensor,
var_tensor,
mean_tensor,
var_tensor,
1e-5,
num_batch,
] ]
bwd_out_args = [input_tensor, weight_tensor, bias_tensor] bwd_out_args = [input_tensor, weight_tensor, bias_tensor]
...@@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt ...@@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# calculate memory cost # calculate memory cost
# the fwd activation cost is output plus saved mean and saved inv std # the fwd activation cost is output plus saved mean and saved inv std
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( fwd_memory_cost = MemoryCost(
[input_tensor, output_tensor, mean_tensor, var_tensor]), activation=compute_size_in_bytes([input_tensor, output_tensor, mean_tensor, var_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0, temp=0,
buffer=compute_size_in_bytes([mean_tensor, var_tensor])) buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
)
# the bwd memory cost is quite tricky here, BatchNorm will remove saved mean # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean
# and saved inv std during backward phase # and saved inv std during backward phase
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]), bwd_memory_cost = MemoryCost(
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), activation=compute_size_in_bytes([input_tensor]),
temp=compute_size_in_bytes([mean_tensor, var_tensor]), parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
buffer=compute_size_in_bytes([mean_tensor, var_tensor])) temp=compute_size_in_bytes([mean_tensor, var_tensor]),
buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
)
# total cost is the sum of forward and backward cost # total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, total_cost = MemoryCost(
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out # store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')] fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')] fwd_buffer = [torch.zeros_like(mean_tensor, device="meta"), torch.zeros_like(var_tensor, device="meta")]
fwd_out = [torch.zeros_like(output_tensor, device='meta')] fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
...@@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem ...@@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
weight_tensor = next(filter(lambda x: x.name == "weight", args)).data weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
bias_tensor = next(filter(lambda x: x.name == "bias", args)).data bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
running_mean = torch.rand(input_tensor.shape[0], 1, device='meta') running_mean = torch.rand(input_tensor.shape[0], 1, device="meta")
running_var = torch.rand(input_tensor.shape[0], 1, device='meta') running_var = torch.rand(input_tensor.shape[0], 1, device="meta")
# construct args # construct args
fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor] fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor]
...@@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem ...@@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# memory cost # memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( fwd_memory_cost = MemoryCost(
[input_tensor, output_tensor, weight_tensor, bias_tensor]), activation=compute_size_in_bytes([input_tensor, output_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0, temp=0,
buffer=compute_size_in_bytes([running_mean, running_var])) buffer=compute_size_in_bytes([running_mean, running_var]),
)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), bwd_memory_cost = MemoryCost(
temp=compute_size_in_bytes([running_mean, running_var]), activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
buffer=compute_size_in_bytes([running_mean, running_var])) parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=compute_size_in_bytes([running_mean, running_var]),
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, buffer=compute_size_in_bytes([running_mean, running_var]),
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, )
temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out # store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')] fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')] fwd_buffer = [torch.zeros_like(running_mean, device="meta"), torch.zeros_like(running_var, device="meta")]
fwd_out = [torch.zeros_like(output_tensor, device='meta')] fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
...@@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, ...@@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
# store fwd_in, fwd_buffer, fwd_out # store fwd_in, fwd_buffer, fwd_out
fwd_in = [] fwd_in = []
fwd_buffer = [] fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')] fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
...@@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, ...@@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix])) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix]))
# temp memory for backward is the index matrix to be discarded # temp memory for backward is the index matrix to be discarded
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix), bwd_mem_cost = MemoryCost(
temp=compute_size_in_bytes(index_matrix)) activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
temp=compute_size_in_bytes(index_matrix),
)
# total cost # total cost
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp) total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)
...@@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, ...@@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out # store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')] fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = [torch.zeros_like(index_matrix, device='meta')] fwd_buffer = [torch.zeros_like(index_matrix, device="meta")]
fwd_out = [torch.zeros_like(output_tensor, device='meta')] fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
...@@ -2,7 +2,6 @@ from typing import Callable, List, Tuple ...@@ -2,7 +2,6 @@ from typing import Callable, List, Tuple
import torch import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
...@@ -37,15 +36,19 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f ...@@ -37,15 +36,19 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor, bwd_mem_cost = MemoryCost(
parameter=0, activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor, parameter=0,
buffer=0) temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
buffer=0,
)
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, total_mem_cost = MemoryCost(
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp, parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
...@@ -66,14 +69,24 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f ...@@ -66,14 +69,24 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
# register torch.Tensor related metainfo # register torch.Tensor related metainfo
# (0, 0) # (0, 0)
meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, torch.arange])(
torch.arange])(tensor_related_metainfo(0, 0)) tensor_related_metainfo(0, 0)
)
# (1, 0) # (1, 0)
meta_register.register([ meta_register.register(
torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute, [
torch.Tensor.split, torch.split, torch.Tensor.view torch.Tensor.flatten,
])(tensor_related_metainfo(1, 0)) torch.flatten,
torch.Tensor.transpose,
torch.transpose,
torch.Tensor.permute,
torch.permute,
torch.Tensor.split,
torch.split,
torch.Tensor.view,
]
)(tensor_related_metainfo(1, 0))
# (1, 1) # (1, 1)
meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1)) meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1))
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register from ..registry import meta_register
...@@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li ...@@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li
# gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase # gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
# NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward # NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward
fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor])) fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor]))
bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]), bwd_mem_cost = MemoryCost(
parameter=0, activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) - parameter=0,
activation_size([x_tensor, y_tensor]), temp=activation_size([output_tensor]) * 3
buffer=0) + activation_size([condition_tensor])
- activation_size([x_tensor, y_tensor]),
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, buffer=0,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, )
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) total_mem_cost = MemoryCost(
activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
......
__all__ = ['Registry'] __all__ = ["Registry"]
class Registry: class Registry:
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
self.store = {} self.store = {}
def register(self, source): def register(self, source):
def wrapper(func): def wrapper(func):
if isinstance(source, (list, tuple)): if isinstance(source, (list, tuple)):
# support register a list of items for this func # support register a list of items for this func
...@@ -21,7 +19,7 @@ class Registry: ...@@ -21,7 +19,7 @@ class Registry:
return wrapper return wrapper
def get(self, source): def get(self, source):
assert source in self.store, f'{source} not found in the {self.name} registry' assert source in self.store, f"{source} not found in the {self.name} registry"
target = self.store[source] target = self.store[source]
return target return target
...@@ -29,4 +27,4 @@ class Registry: ...@@ -29,4 +27,4 @@ class Registry:
return source in self.store return source in self.store
meta_register = Registry('meta') meta_register = Registry("meta")
...@@ -2,20 +2,13 @@ from typing import Callable, List ...@@ -2,20 +2,13 @@ from typing import Callable, List
import torch import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register from .registry import meta_register
__all__ = ['ShardMetaInfo'] __all__ = ["ShardMetaInfo"]
class ShardMetaInfo: class ShardMetaInfo:
...@@ -76,10 +69,12 @@ class ShardMetaInfo: ...@@ -76,10 +69,12 @@ class ShardMetaInfo:
""" """
if isinstance(sharding_spec, ShardingSpec): if isinstance(sharding_spec, ShardingSpec):
op_data = OperationData(name=operation_data.name, op_data = OperationData(
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), name=operation_data.name,
type=operation_data.type, data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
logical_shape=operation_data.logical_shape) type=operation_data.type,
logical_shape=operation_data.logical_shape,
)
elif isinstance(sharding_spec, (list, tuple)): elif isinstance(sharding_spec, (list, tuple)):
data = operation_data.data data = operation_data.data
assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}." assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
...@@ -97,8 +92,9 @@ class ShardMetaInfo: ...@@ -97,8 +92,9 @@ class ShardMetaInfo:
""" """
Compute meta info based on sharding strategy and the given target function. Compute meta info based on sharding strategy and the given target function.
""" """
assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \ assert meta_register.has(self._target.__class__) or meta_register.has(
f"Meta info for {self._target} is not registered." self._target
), f"Meta info for {self._target} is not registered."
if meta_register.has(self._target.__class__): if meta_register.has(self._target.__class__):
# module # module
meta_func = meta_register.get(self._target.__class__) meta_func = meta_register.get(self._target.__class__)
...@@ -117,11 +113,11 @@ class ShardMetaInfo: ...@@ -117,11 +113,11 @@ class ShardMetaInfo:
# construct kwargs # construct kwargs
if self.target in INPLACE_MODULE: if self.target in INPLACE_MODULE:
kwargs = {'inplace': self.target.inplace} kwargs = {"inplace": self.target.inplace}
elif self.target in INPLACE_OPS: elif self.target in INPLACE_OPS:
kwargs = {'inplace': True} kwargs = {"inplace": True}
else: else:
kwargs = {'inplace': False} kwargs = {"inplace": False}
# compute metainfo with meta_func # compute metainfo with meta_func
self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs) self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
......
from typing import Dict, Tuple
from enum import Enum from enum import Enum
from typing import Dict, Tuple
import torch import torch
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from .base_offload_module import BaseOffloadModule from .base_offload_module import BaseOffloadModule
from .region_manager import RegionManager
from .region import Region from .region import Region
from .region_manager import RegionManager
class OptimState(Enum): class OptimState(Enum):
SCALED = 0 SCALED = 0
UNSCALED = 1 UNSCALED = 1
class AMPOptimizer(ColossalaiOptimizer):
class AMPOptimizer(OptimizerWrapper):
""" """
A wrapper for Optimizer. A wrapper for Optimizer.
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
...@@ -36,19 +37,20 @@ class AMPOptimizer(ColossalaiOptimizer): ...@@ -36,19 +37,20 @@ class AMPOptimizer(ColossalaiOptimizer):
norm_type (float, optional): norm_type used for `clip_grad_norm`. norm_type (float, optional): norm_type used for `clip_grad_norm`.
""" """
def __init__(self, def __init__(
optimizer: Optimizer, self,
module: BaseOffloadModule, optimizer: Optimizer,
initial_scale: float = 2**16, module: BaseOffloadModule,
growth_factor: float = 2, initial_scale: float = 2**16,
backoff_factor: float = 0.5, growth_factor: float = 2,
growth_interval: int = 1000, backoff_factor: float = 0.5,
hysteresis: int = 2, growth_interval: int = 1000,
min_scale: float = 1, hysteresis: int = 2,
max_scale: float = 2**32, min_scale: float = 1,
clipping_norm: float = 0.0, max_scale: float = 2**32,
norm_type: float = 2.0): clipping_norm: float = 0.0,
norm_type: float = 2.0,
):
super().__init__(optimizer) super().__init__(optimizer)
self.module = module self.module = module
...@@ -68,19 +70,21 @@ class AMPOptimizer(ColossalaiOptimizer): ...@@ -68,19 +70,21 @@ class AMPOptimizer(ColossalaiOptimizer):
self.__init__optimizer() self.__init__optimizer()
# Grad scaler # Grad scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, self.grad_scaler = DynamicGradScaler(
min_scale=min_scale, initial_scale=initial_scale,
growth_factor=growth_factor, min_scale=min_scale,
backoff_factor=backoff_factor, growth_factor=growth_factor,
growth_interval=growth_interval, backoff_factor=backoff_factor,
hysteresis=hysteresis, growth_interval=growth_interval,
max_scale=max_scale) hysteresis=hysteresis,
max_scale=max_scale,
)
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
self._logger = get_dist_logger() self._logger = get_dist_logger()
def _set_grad_ptr(self): def _set_grad_ptr(self):
for group in self.param_groups: for group in self.param_groups:
for fake_param in group['params']: for fake_param in group["params"]:
region = self.param_to_region[fake_param] region = self.param_to_region[fake_param]
begin, end = self.param_to_range[fake_param] begin, end = self.param_to_range[fake_param]
...@@ -91,7 +95,7 @@ class AMPOptimizer(ColossalaiOptimizer): ...@@ -91,7 +95,7 @@ class AMPOptimizer(ColossalaiOptimizer):
def _update_fp16_params(self): def _update_fp16_params(self):
none_tensor = torch.empty([0]) none_tensor = torch.empty([0])
for group in self.param_groups: for group in self.param_groups:
for fake_param in group['params']: for fake_param in group["params"]:
assert fake_param.grad is None assert fake_param.grad is None
fake_param.data = none_tensor fake_param.data = none_tensor
self.param_to_region[fake_param].cpu_grad = None self.param_to_region[fake_param].cpu_grad = None
...@@ -129,10 +133,10 @@ class AMPOptimizer(ColossalaiOptimizer): ...@@ -129,10 +133,10 @@ class AMPOptimizer(ColossalaiOptimizer):
found_inf = self._check_overflow() found_inf = self._check_overflow()
if found_inf: if found_inf:
self.optim_state = OptimState.UNSCALED # no need to unscale grad self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler self.grad_scaler.update(found_inf) # update gradient scaler
self._logger.info(f'Found overflow. Skip step') self._logger.info(f"Found overflow. Skip step")
self.zero_grad() # reset all gradients self.zero_grad() # reset all gradients
self._update_fp16_params() self._update_fp16_params()
return return
...@@ -155,11 +159,10 @@ class AMPOptimizer(ColossalaiOptimizer): ...@@ -155,11 +159,10 @@ class AMPOptimizer(ColossalaiOptimizer):
self.module.backward(loss) self.module.backward(loss)
def __init__optimizer(self): def __init__optimizer(self):
for group in self.optim.param_groups: for group in self.optim.param_groups:
fake_params_list = list() fake_params_list = list()
for param in group['params']: for param in group["params"]:
region = self.region_manager.get_region(param) region = self.region_manager.get_region(param)
fake_param = torch.nn.Parameter(torch.empty([0])) fake_param = torch.nn.Parameter(torch.empty([0]))
self.param_to_range[fake_param] = region.param_to_range[param] self.param_to_range[fake_param] = region.param_to_range[param]
...@@ -170,8 +173,8 @@ class AMPOptimizer(ColossalaiOptimizer): ...@@ -170,8 +173,8 @@ class AMPOptimizer(ColossalaiOptimizer):
if param in self.optim.state: if param in self.optim.state:
self.optim.state[fake_param] = self.optim.state.pop(param) self.optim.state[fake_param] = self.optim.state.pop(param)
group['params'] = fake_params_list group["params"] = fake_params_list
# Leverage state_dict() and load_state_dict() to # Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors # recast preexisting per-param state tensors
self.optim.load_state_dict(self.optim.state_dict()) self.optim.load_state_dict(self.optim.state_dict())
\ No newline at end of file
...@@ -4,7 +4,7 @@ from typing import Optional, Set ...@@ -4,7 +4,7 @@ from typing import Optional, Set
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.nn.parallel.data_parallel import _cast_float from colossalai.utils import _cast_float
from colossalai.zero.legacy.gemini.tensor_utils import free_storage from colossalai.zero.legacy.gemini.tensor_utils import free_storage
from .region_manager import RegionManager from .region_manager import RegionManager
...@@ -22,7 +22,6 @@ class BaseOffloadModule: ...@@ -22,7 +22,6 @@ class BaseOffloadModule:
""" """
def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True): def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):
self.model = model self.model = model
self.region_manager = region_manager self.region_manager = region_manager
self.grad_hook_list = [] self.grad_hook_list = []
...@@ -91,17 +90,16 @@ class BaseOffloadModule: ...@@ -91,17 +90,16 @@ class BaseOffloadModule:
def parameters(self, recurse: bool = True): def parameters(self, recurse: bool = True):
return self.model.parameters(recurse) return self.model.parameters(recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True): def named_parameters(self, prefix: str = "", recurse: bool = True):
return self.model.named_parameters(prefix, recurse) return self.model.named_parameters(prefix, recurse)
def named_buffers(self, prefix: str = '', recurse: bool = True): def named_buffers(self, prefix: str = "", recurse: bool = True):
return self.model.named_buffers(prefix, recurse) return self.model.named_buffers(prefix, recurse)
def named_children(self): def named_children(self):
return self.model.named_children() return self.model.named_children()
def named_modules(self, def named_modules(
memo: Optional[Set[torch.nn.Module]] = None, self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
prefix: str = '', ):
remove_duplicate: bool = True):
return self.model.named_modules(memo, prefix, remove_duplicate) return self.model.named_modules(memo, prefix, remove_duplicate)
...@@ -14,11 +14,9 @@ from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_ ...@@ -14,11 +14,9 @@ from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_
from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem
def memory_optimize(model: torch.nn.Module, def memory_optimize(
inps: Dict[str, torch.Tensor], model: torch.nn.Module, inps: Dict[str, torch.Tensor], memory_budget: float = -1.0, solver_name: str = "asyn"
memory_budget: float = -1.0, ):
solver_name: str = 'asyn'):
model = model.cpu().half() model = model.cpu().half()
tracer = ColoTracer() tracer = ColoTracer()
assert is_compatible_with_meta() assert is_compatible_with_meta()
...@@ -40,13 +38,13 @@ def memory_optimize(model: torch.nn.Module, ...@@ -40,13 +38,13 @@ def memory_optimize(model: torch.nn.Module,
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}" f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}"
) )
if solver_name == 'syn': if solver_name == "syn":
gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list) gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
elif solver_name == 'asyn': elif solver_name == "asyn":
gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list) gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list)
else: else:
raise TypeError(f"Unknown solver name {solver_name}!") raise TypeError(f"Unknown solver name {solver_name}!")
gm.recompile() gm.recompile()
optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn') optimized_model = BaseOffloadModule(gm, region_manager, solver_name == "syn")
return optimized_model return optimized_model
...@@ -55,13 +55,13 @@ class Region: ...@@ -55,13 +55,13 @@ class Region:
Map the parameters in the region to a contiguous memory space. Map the parameters in the region to a contiguous memory space.
""" """
self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda') self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device="cuda")
offset = 0 offset = 0
for param in self.fp16_params: for param in self.fp16_params:
param.data = param.data.cuda() param.data = param.data.cuda()
p_num = param.data.numel() p_num = param.data.numel()
self.fp16_data[offset:offset + p_num].copy_(param.data.flatten()) self.fp16_data[offset : offset + p_num].copy_(param.data.flatten())
param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape) param.data = self.fp16_data[offset : offset + p_num].view(param.data.shape)
self.param_to_range[param] = (offset, offset + p_num) self.param_to_range[param] = (offset, offset + p_num)
offset += p_num offset += p_num
...@@ -83,7 +83,7 @@ class Region: ...@@ -83,7 +83,7 @@ class Region:
self.temp_fp32_data.record_stream(torch.cuda.current_stream()) self.temp_fp32_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag: if not self.in_mem_pool_flag:
alloc_storage(self.fp16_data) alloc_storage(self.fp16_data)
self.fp16_data[:self.param_num].copy_(self.temp_fp32_data) self.fp16_data[: self.param_num].copy_(self.temp_fp32_data)
self.fp16_data.record_stream(torch.cuda.current_stream()) self.fp16_data.record_stream(torch.cuda.current_stream())
self.__update_params_ptr() self.__update_params_ptr()
...@@ -94,7 +94,7 @@ class Region: ...@@ -94,7 +94,7 @@ class Region:
""" """
self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True) self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True)
self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True) self.cpu_grad.copy_(self.fp16_data[: self.param_num], non_blocking=True)
self.fp16_data.record_stream(torch.cuda.current_stream()) self.fp16_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag: if not self.in_mem_pool_flag:
self.free_cuda_data() self.free_cuda_data()
......
from typing import List, Any, Dict, Tuple from typing import Any, Dict, List, Tuple
import torch import torch
from torch.fx import Graph, Node from torch.fx import Graph, Node
from .region import Region
from .solver import SolverFactory from .solver import SolverFactory
from .training_simulator import TrainingSimulator from .training_simulator import TrainingSimulator
from .region import Region
from .util import NodeInfo from .util import NodeInfo
...@@ -19,14 +20,9 @@ class RegionManager: ...@@ -19,14 +20,9 @@ class RegionManager:
cnode (List[str], optional): Common node List, should be the subset of input. cnode (List[str], optional): Common node List, should be the subset of input.
""" """
def __init__(self, def __init__(self, graph: Graph, solver_name: str = "asyn", memory_budget: float = -1.0, cnode: List[str] = None):
graph: Graph,
solver_name: str = 'asyn',
memory_budget: float = -1.0,
cnode: List[str] = None):
self.graph = graph self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' assert graph.owning_module is not None, "The given graph is not associated with a owning_module"
self.root_module = self.graph.owning_module self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes) self.nodes = list(graph.nodes)
self.cnode = cnode self.cnode = cnode
...@@ -39,7 +35,7 @@ class RegionManager: ...@@ -39,7 +35,7 @@ class RegionManager:
self.memory_budget = memory_budget self.memory_budget = memory_budget
self.solver_name = solver_name self.solver_name = solver_name
self.require_pool: bool = solver_name == 'asyn' self.require_pool: bool = solver_name == "asyn"
self.reg_to_block: Dict[int, int] = dict() self.reg_to_block: Dict[int, int] = dict()
...@@ -61,22 +57,19 @@ class RegionManager: ...@@ -61,22 +57,19 @@ class RegionManager:
self._post_process(solver.best_ts) self._post_process(solver.best_ts)
def _pre_process(self): def _pre_process(self):
init_region_list = self._linearize_graph() init_region_list = self._linearize_graph()
if len(self.shared_region_pairs) > 1: if len(self.shared_region_pairs) > 1:
raise NotImplementedError( raise NotImplementedError("The current version only considers at most one pair of parameter sharing.")
'The current version only considers at most one pair of parameter sharing.')
elif len(self.shared_region_pairs) == 1: elif len(self.shared_region_pairs) == 1:
shared_regs = self.shared_region_pairs[0] shared_regs = self.shared_region_pairs[0]
assert shared_regs[0].shared_rid == shared_regs[1].r_id \ assert shared_regs[0].shared_rid == shared_regs[1].r_id and shared_regs[1].shared_rid == shared_regs[0].r_id
and shared_regs[1].shared_rid == shared_regs[0].r_id
fst_id = shared_regs[0].r_id fst_id = shared_regs[0].r_id
lst_id = shared_regs[1].r_id lst_id = shared_regs[1].r_id
regs_left_out = init_region_list[:fst_id + 1] regs_left_out = init_region_list[: fst_id + 1]
regs_right_out = init_region_list[lst_id:] regs_right_out = init_region_list[lst_id:]
hold_regs = init_region_list[fst_id + 1:lst_id] hold_regs = init_region_list[fst_id + 1 : lst_id]
else: else:
regs_left_out = [] regs_left_out = []
regs_right_out = [] regs_right_out = []
...@@ -122,12 +115,9 @@ class RegionManager: ...@@ -122,12 +115,9 @@ class RegionManager:
it may not find a suitable region placement strategy for the given execution flow. it may not find a suitable region placement strategy for the given execution flow.
""" """
reg_flow = torch.cat( reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
[ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0) mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
mem_block_num = torch.max( coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow)
torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
coexist_matrix = torch.logical_or(
ts.fwd_reg_flow, ts.bwd_reg_flow)
block_to_regs = {} block_to_regs = {}
for block_idx in range(mem_block_num): for block_idx in range(mem_block_num):
...@@ -135,8 +125,7 @@ class RegionManager: ...@@ -135,8 +125,7 @@ class RegionManager:
for reg in self.region_list: for reg in self.region_list:
if reg.r_id in self.rid_in_pool: if reg.r_id in self.rid_in_pool:
cur_reg_appears = coexist_matrix[:, reg.r_id] cur_reg_appears = coexist_matrix[:, reg.r_id]
cur_reg_coexists = torch.sum( cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool()
coexist_matrix[cur_reg_appears], dim=0).bool()
for block_idx in range(mem_block_num): for block_idx in range(mem_block_num):
if not any(cur_reg_coexists[block_to_regs[block_idx]]): if not any(cur_reg_coexists[block_to_regs[block_idx]]):
block_to_regs[block_idx].append(reg.r_id) block_to_regs[block_idx].append(reg.r_id)
...@@ -145,9 +134,12 @@ class RegionManager: ...@@ -145,9 +134,12 @@ class RegionManager:
if reg.r_id not in self.reg_to_block: if reg.r_id not in self.reg_to_block:
raise NotImplementedError( raise NotImplementedError(
f'can not find a block from the memory pool to store parameters of the region') f"can not find a block from the memory pool to store parameters of the region"
self.memory_pool = torch.chunk(torch.zeros(int( )
mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num)) self.memory_pool = torch.chunk(
torch.zeros(int(mem_block_num * self.mem_block_size / 2), dtype=torch.half, device="cuda"),
chunks=int(mem_block_num),
)
def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]: def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
""" """
...@@ -178,10 +170,9 @@ class RegionManager: ...@@ -178,10 +170,9 @@ class RegionManager:
return region_list return region_list
def _search_block_size(self, def _search_block_size(
region_list: List[Region], self, region_list: List[Region], search_interval_byte: int = 1024, search_range_byte: int = 128 * 1024**2
search_interval_byte: int = 1024, ) -> int:
search_range_byte: int = 128 * 1024 ** 2) -> int:
""" """
Search for a suitable memory block size. Search for a suitable memory block size.
...@@ -208,11 +199,10 @@ class RegionManager: ...@@ -208,11 +199,10 @@ class RegionManager:
acc_wasted += blk_size - left acc_wasted += blk_size - left
return acc_wasted return acc_wasted
param_size_list = [ param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid]
region.param_size for region in region_list if region.r_id == region.shared_rid]
start_size = max(param_size_list) start_size = max(param_size_list)
min_mem_waste = float('+inf') min_mem_waste = float("+inf")
best_block_size = start_size best_block_size = start_size
for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
...@@ -229,7 +219,7 @@ class RegionManager: ...@@ -229,7 +219,7 @@ class RegionManager:
Initialize region data, which maps the parameters in the region to a contiguous memory space. Initialize region data, which maps the parameters in the region to a contiguous memory space.
""" """
self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32) self.temp_fp32_data = torch.zeros(self.max_param_num, device="cuda", dtype=torch.float32)
for region in self.region_list: for region in self.region_list:
pre_alloc_tensor = None pre_alloc_tensor = None
...@@ -244,8 +234,7 @@ class RegionManager: ...@@ -244,8 +234,7 @@ class RegionManager:
region.fp16_data = shared_region.fp16_data region.fp16_data = shared_region.fp16_data
region.fp32_data = shared_region.fp32_data region.fp32_data = shared_region.fp32_data
region.param_to_range = shared_region.param_to_range region.param_to_range = shared_region.param_to_range
region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach( region.temp_fp32_data = self.temp_fp32_data[: region.param_num].detach()
)
torch.cuda.empty_cache() torch.cuda.empty_cache()
...@@ -259,13 +248,14 @@ class RegionManager: ...@@ -259,13 +248,14 @@ class RegionManager:
former_reg, latter_reg = self.shared_region_pairs[0] former_reg, latter_reg = self.shared_region_pairs[0]
assert latter_reg.param_num >= former_reg.param_num assert latter_reg.param_num >= former_reg.param_num
embedding_node = former_reg.nodes[-1] embedding_node = former_reg.nodes[-1]
assert embedding_node.op == 'call_module' and isinstance( assert embedding_node.op == "call_module" and isinstance(
self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding) self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding
)
if latter_reg.param_num > former_reg.param_num: if latter_reg.param_num > former_reg.param_num:
for idx, n in enumerate(latter_reg.nodes): for idx, n in enumerate(latter_reg.nodes):
if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target), if (
torch.nn.Linear)) or \ n.op == "call_module" and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear)
(n.op == 'call_function' and n.target is torch.nn.functional.linear): ) or (n.op == "call_function" and n.target is torch.nn.functional.linear):
cut_node_idx = idx + 1 cut_node_idx = idx + 1
break break
assert len(latter_reg.fp16_params) == 2 assert len(latter_reg.fp16_params) == 2
...@@ -273,7 +263,7 @@ class RegionManager: ...@@ -273,7 +263,7 @@ class RegionManager:
for p in new_reg.fp16_params: for p in new_reg.fp16_params:
self.param_region_map[p] = new_reg self.param_region_map[p] = new_reg
self.region_list.insert(new_reg.r_id, new_reg) self.region_list.insert(new_reg.r_id, new_reg)
for reg in self.region_list[new_reg.r_id + 1:]: for reg in self.region_list[new_reg.r_id + 1 :]:
reg.r_id += 1 reg.r_id += 1
latter_reg.shared_rid = former_reg.r_id latter_reg.shared_rid = former_reg.r_id
former_reg.shared_rid = latter_reg.r_id former_reg.shared_rid = latter_reg.r_id
...@@ -344,8 +334,8 @@ class RegionManager: ...@@ -344,8 +334,8 @@ class RegionManager:
target = n.target target = n.target
submod = self.root_module.get_submodule(target) submod = self.root_module.get_submodule(target)
if ( if (
len(list(submod.named_parameters(recurse=False))) != 0 len(list(submod.named_parameters(recurse=False))) != 0
or len(list(submod.named_buffers(recurse=False))) != 0 or len(list(submod.named_buffers(recurse=False))) != 0
): ):
label = True label = True
...@@ -362,14 +352,12 @@ class RegionManager: ...@@ -362,14 +352,12 @@ class RegionManager:
""" """
def _is_inplace(n: Node): def _is_inplace(n: Node):
"""Get the inplace argument from ``torch.fx.Node`` """Get the inplace argument from ``torch.fx.Node``"""
"""
inplace = False inplace = False
if n.op == "call_function": if n.op == "call_function":
inplace = n.kwargs.get("inplace", False) inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module": elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule( inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
n.target), "inplace", False)
return inplace return inplace
label = False label = False
...@@ -378,28 +366,30 @@ class RegionManager: ...@@ -378,28 +366,30 @@ class RegionManager:
target = n.target target = n.target
submod = self.root_module.get_submodule(target) submod = self.root_module.get_submodule(target)
if ( if (
len(list(submod.named_parameters(recurse=False))) != 0 len(list(submod.named_parameters(recurse=False))) != 0
or len(list(submod.named_buffers(recurse=False))) != 0 or len(list(submod.named_buffers(recurse=False))) != 0
): ):
label = True label = True
elif n.op == "call_function": elif n.op == "call_function":
label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any( label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any(
map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)) map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)
)
return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users)) return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users))
def _exception_node_handling(): def _exception_node_handling():
# TODO meta info prop bug # TODO meta info prop bug
if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2: if n.name.__contains__("transpose") and n.meta["fwd_out"][0].dim() <= 2:
n.meta['fwd_out'] = [] n.meta["fwd_out"] = []
# make sure that item in cnode is valid # make sure that item in cnode is valid
if self.cnode: if self.cnode:
for name in self.cnode: for name in self.cnode:
try: try:
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \ assert (
f"Common node {name} is not an input of the model." next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
), f"Common node {name} is not an input of the model."
except StopIteration: except StopIteration:
raise ValueError(f"Common node name {name} not in graph.") raise ValueError(f"Common node name {name} not in graph.")
else: else:
...@@ -428,8 +418,8 @@ class RegionManager: ...@@ -428,8 +418,8 @@ class RegionManager:
ns = [] ns = []
border_n_idx = region.nodes.index(act_n) border_n_idx = region.nodes.index(act_n)
if border_n_idx < len(region.nodes): if border_n_idx < len(region.nodes):
ns = region.nodes[border_n_idx + 1:] ns = region.nodes[border_n_idx + 1 :]
region.nodes = region.nodes[:border_n_idx + 1] region.nodes = region.nodes[: border_n_idx + 1]
region_list.append(region) region_list.append(region)
region_id += 1 region_id += 1
region = Region(r_id=region_id) region = Region(r_id=region_id)
...@@ -448,19 +438,21 @@ class RegionManager: ...@@ -448,19 +438,21 @@ class RegionManager:
region = Region(r_id=region_id) region = Region(r_id=region_id)
# propagate common node attr if possible # propagate common node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode if len(n.all_input_nodes) == len(
]) or _is_cop(n.target): [node for node in n.all_input_nodes if node.name in self.cnode]
) or _is_cop(n.target):
self.cnode.append(n.name) self.cnode.append(n.name)
else: else:
deps[n] = len( deps[n] = len([user for user in n.users if user.op != "output"])
[user for user in n.users if user.op != "output"])
# propagate param node attr if possible # propagate param node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops if (
]) or n.op == "get_attr": len(n.all_input_nodes)
== len([node for node in n.all_input_nodes if node.name in self.only_param_ops])
or n.op == "get_attr"
):
self.only_param_ops.append(n.name) self.only_param_ops.append(n.name)
param_op_deps[n] = len( param_op_deps[n] = len([user for user in n.users if user.op != "output"])
[user for user in n.users if user.op != "output"])
# record last activation node # record last activation node
if _is_act(n._meta_data): if _is_act(n._meta_data):
...@@ -472,19 +464,16 @@ class RegionManager: ...@@ -472,19 +464,16 @@ class RegionManager:
return region_list return region_list
def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region): def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
cur_n.node_info = NodeInfo(node_id) cur_n.node_info = NodeInfo(node_id)
if cur_n.op == 'call_module': if cur_n.op == "call_module":
target = cur_n.target target = cur_n.target
submod = self.root_module.get_submodule(target) submod = self.root_module.get_submodule(target)
for p in list(submod.parameters(recurse=False)): for p in list(submod.parameters(recurse=False)):
if p in self.param_region_map: if p in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[p].r_id cur_reg.shared_rid = self.param_region_map[p].r_id
self.param_region_map[p].shared_rid = cur_reg.r_id self.param_region_map[p].shared_rid = cur_reg.r_id
self.shared_region_pairs.append( self.shared_region_pairs.append((self.param_region_map[p], cur_reg))
(self.param_region_map[p], cur_reg))
else: else:
self.param_region_map[p] = cur_reg self.param_region_map[p] = cur_reg
...@@ -499,12 +488,10 @@ class RegionManager: ...@@ -499,12 +488,10 @@ class RegionManager:
attr_itr = getattr(attr_itr, atom) attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.nn.Parameter): if isinstance(attr_itr, torch.nn.Parameter):
if attr_itr in self.param_region_map: if attr_itr in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[attr_itr].r_id cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
self.param_region_map[attr_itr].shared_rid = cur_reg.r_id self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
self.shared_region_pairs.append( self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg))
(self.param_region_map[attr_itr], cur_reg))
else: else:
self.param_region_map[attr_itr] = cur_reg self.param_region_map[attr_itr] = cur_reg
......
...@@ -22,13 +22,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function): ...@@ -22,13 +22,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, input_, fwd_info, bwd_info): def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info ctx.bwd_info = bwd_info
d2h_rid = fwd_info.get('d2h_rid', None) d2h_rid = fwd_info.get("d2h_rid", None)
if d2h_rid is not None: if d2h_rid is not None:
free_region = GlobalRuntimeInfo().region_list[d2h_rid] free_region = GlobalRuntimeInfo().region_list[d2h_rid]
assert isinstance(free_region, Region) assert isinstance(free_region, Region)
free_region.free_cuda_data() free_region.free_cuda_data()
h2d_rid = fwd_info.get('h2d_rid', None) h2d_rid = fwd_info.get("h2d_rid", None)
if h2d_rid is not None: if h2d_rid is not None:
h2d_region = GlobalRuntimeInfo().region_list[h2d_rid] h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(h2d_region, Region) assert isinstance(h2d_region, Region)
...@@ -38,8 +38,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function): ...@@ -38,8 +38,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
h2d_rid = ctx.bwd_info.get("h2d_rid", None)
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
if h2d_rid is not None: if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid] pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region) assert isinstance(pref_region, Region)
...@@ -64,13 +63,13 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function): ...@@ -64,13 +63,13 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
def forward(ctx, input_, fwd_info, bwd_info): def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info ctx.bwd_info = bwd_info
sync_rid = fwd_info.get('sync_rid', None) sync_rid = fwd_info.get("sync_rid", None)
if sync_rid is not None: if sync_rid is not None:
prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None) prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
if prefetch_event: if prefetch_event:
prefetch_event.wait() prefetch_event.wait()
h2d_rid = fwd_info.get('h2d_rid', None) h2d_rid = fwd_info.get("h2d_rid", None)
if h2d_rid is not None: if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid] pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region) assert isinstance(pref_region, Region)
...@@ -87,8 +86,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function): ...@@ -87,8 +86,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
sync_rid = ctx.bwd_info.get("sync_rid", None)
sync_rid = ctx.bwd_info.get('sync_rid', None)
if sync_rid is not None: if sync_rid is not None:
wait_region = GlobalRuntimeInfo().region_list[sync_rid] wait_region = GlobalRuntimeInfo().region_list[sync_rid]
assert isinstance(wait_region, Region) assert isinstance(wait_region, Region)
...@@ -98,7 +96,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function): ...@@ -98,7 +96,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
else: else:
wait_region.move_param_to_cuda() wait_region.move_param_to_cuda()
h2d_rid = ctx.bwd_info.get('h2d_rid', None) h2d_rid = ctx.bwd_info.get("h2d_rid", None)
if h2d_rid is not None: if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid] pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region) assert isinstance(pref_region, Region)
...@@ -114,7 +112,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function): ...@@ -114,7 +112,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info): def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
''' """
Convert Upload and Offload operation into runtime action. Convert Upload and Offload operation into runtime action.
Argument: Argument:
...@@ -123,14 +121,14 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info): ...@@ -123,14 +121,14 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be uploaded, or freed during forward pass. that need to be uploaded, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices bwd_info(dict): information dict, which contains region indices
that need to be uploaded during backward pass. that need to be uploaded during backward pass.
''' """
with torch._C.DisableTorchFunction(): with torch._C.DisableTorchFunction():
ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret return ret
def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info): def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
''' """
Convert Prefetch and Offload operation into runtime action. Convert Prefetch and Offload operation into runtime action.
Argument: Argument:
...@@ -139,7 +137,7 @@ def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info): ...@@ -139,7 +137,7 @@ def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be prefetched, waited, or freed during forward pass. that need to be prefetched, waited, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices bwd_info(dict): information dict, which contains region indices
that need to be prefetched or waited during backward pass. that need to be prefetched or waited during backward pass.
''' """
with torch._C.DisableTorchFunction(): with torch._C.DisableTorchFunction():
ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret return ret
...@@ -176,22 +174,22 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R ...@@ -176,22 +174,22 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R
# forward upload # forward upload
fwd_info = {} fwd_info = {}
if requires_upload_p_in_fwd(region_list[region.shared_rid]): if requires_upload_p_in_fwd(region_list[region.shared_rid]):
fwd_info['h2d_rid'] = region.r_id fwd_info["h2d_rid"] = region.r_id
# forward offload # forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload: if r_idx > 0 and region_list[r_idx - 1].need_offload:
fwd_info['d2h_rid'] = r_idx - 1 fwd_info["d2h_rid"] = r_idx - 1
bwd_info = {} bwd_info = {}
# backward upload # backward upload
if r_idx > 0 and region_list[r_idx - 1].need_offload: if r_idx > 0 and region_list[r_idx - 1].need_offload:
bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id bwd_info["h2d_rid"] = region_list[r_idx - 1].r_id
if fwd_info or bwd_info: if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node): with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', new_node = mod_graph.create_node(
convert_fwd_upload_bwd_offload_to_action, "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info)
args=(last_inp_node, fwd_info, bwd_info)) )
replace_node_users(last_inp_node, new_node) replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1] last_inp_node = region.nodes[-1]
...@@ -210,9 +208,9 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ ...@@ -210,9 +208,9 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
first_region_with_p = [region for region in region_list if region.param_size][0] first_region_with_p = [region for region in region_list if region.param_size][0]
fwd_info = {"h2d_rid": first_region_with_p.r_id} fwd_info = {"h2d_rid": first_region_with_p.r_id}
with mod_graph.inserting_after(last_inp_node): with mod_graph.inserting_after(last_inp_node):
upload_apply_node = mod_graph.create_node('call_function', upload_apply_node = mod_graph.create_node(
convert_fwd_upload_bwd_offload_to_action, "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, {})
args=(last_inp_node, fwd_info, {})) )
replace_node_users(last_inp_node, upload_apply_node) replace_node_users(last_inp_node, upload_apply_node)
last_inp_node = upload_apply_node last_inp_node = upload_apply_node
...@@ -220,37 +218,39 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ ...@@ -220,37 +218,39 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
# forward prefetch # forward prefetch
fwd_info = {} fwd_info = {}
if region.param_size: if region.param_size:
fwd_info['sync_rid'] = region.r_id fwd_info["sync_rid"] = region.r_id
fwd_prefetch_region = region.fwd_prefetch_region fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]): if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
fwd_info['h2d_rid'] = fwd_prefetch_region.r_id fwd_info["h2d_rid"] = fwd_prefetch_region.r_id
# forward offload # forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload: if r_idx > 0 and region_list[r_idx - 1].need_offload:
fwd_info['d2h_rid'] = r_idx - 1 fwd_info["d2h_rid"] = r_idx - 1
bwd_info = {} bwd_info = {}
# backward prefetch # backward prefetch
if r_idx > 0 and region_list[r_idx - 1].need_offload: if r_idx > 0 and region_list[r_idx - 1].need_offload:
bwd_info['sync_rid'] = r_idx - 1 bwd_info["sync_rid"] = r_idx - 1
if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region: if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id bwd_info["h2d_rid"] = region_list[r_idx - 1].bwd_prefetch_region.r_id
if fwd_info or bwd_info: if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node): with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', new_node = mod_graph.create_node(
convert_fwd_prefetch_bwd_offload_to_action, "call_function",
args=(last_inp_node, fwd_info, bwd_info)) convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info),
)
replace_node_users(last_inp_node, new_node) replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1] last_inp_node = region.nodes[-1]
if region.bwd_prefetch_region: if region.bwd_prefetch_region:
bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id} bwd_info = {"h2d_rid": region.bwd_prefetch_region.r_id}
with mod_graph.inserting_after(last_inp_node): with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', new_node = mod_graph.create_node(
convert_fwd_prefetch_bwd_offload_to_action, "call_function", convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, {}, bwd_info)
args=(last_inp_node, {}, bwd_info)) )
replace_node_users(last_inp_node, new_node) replace_node_users(last_inp_node, new_node)
# gm.graph.print_tabular() # gm.graph.print_tabular()
return gm return gm
import time import time
from typing import List, Dict, Type
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Type
NOT_NVML = False NOT_NVML = False
try: try:
...@@ -10,10 +10,11 @@ except: ...@@ -10,10 +10,11 @@ except:
import torch import torch
from torch.fx.node import Node from torch.fx.node import Node
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator
from .region import Region from .region import Region
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
from .util import NodeInfo, NvDevicePower from .util import NodeInfo, NvDevicePower
...@@ -49,19 +50,14 @@ class Solver(ABC): ...@@ -49,19 +50,14 @@ class Solver(ABC):
It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time. It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time.
""" """
def __init__(self, def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error_factor: float = 0.95) -> None:
region_list: List[Region],
memory_budget: float = -1.0,
error_factor: float = 0.95) -> None:
self.region_list = region_list self.region_list = region_list
self.error_factor: float = error_factor self.error_factor: float = error_factor
if memory_budget > 0: if memory_budget > 0:
self.memory_budget = memory_budget * self.error_factor self.memory_budget = memory_budget * self.error_factor
else: else:
self.memory_budget = torch.cuda.get_device_properties( self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
get_current_device()).total_memory * self.error_factor
self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth() self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
self.comp_power: float = self._extract_computing_power() self.comp_power: float = self._extract_computing_power()
...@@ -94,7 +90,7 @@ class Solver(ABC): ...@@ -94,7 +90,7 @@ class Solver(ABC):
if extra_cost == 0: if extra_cost == 0:
# means data transfer overhead can be completely overlapped # means data transfer overhead can be completely overlapped
return (float('inf'), total_mem_saving, peak_mem_saving) return (float("inf"), total_mem_saving, peak_mem_saving)
return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving) return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)
def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool: def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool:
...@@ -122,9 +118,7 @@ class Solver(ABC): ...@@ -122,9 +118,7 @@ class Solver(ABC):
self.best_ts = best_ts self.best_ts = best_ts
self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem) self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem)
def _update_node_mem_info(self, def _update_node_mem_info(self, fwd_mem_info: Dict[Node, float], bwd_mem_info: Dict[Node, float]):
fwd_mem_info: Dict[Node, float],
bwd_mem_info: Dict[Node, float]):
""" """
Update the runtime memory information of the node. Update the runtime memory information of the node.
...@@ -134,12 +128,10 @@ class Solver(ABC): ...@@ -134,12 +128,10 @@ class Solver(ABC):
""" """
for node, mem in fwd_mem_info.items(): for node, mem in fwd_mem_info.items():
assert hasattr(node, 'node_info') and isinstance( assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
node.node_info, NodeInfo)
node.node_info.runtime_fwd_mem = mem node.node_info.runtime_fwd_mem = mem
for node, mem in bwd_mem_info.items(): for node, mem in bwd_mem_info.items():
assert hasattr(node, 'node_info') and isinstance( assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
node.node_info, NodeInfo)
node.node_info.runtime_bwd_mem = mem node.node_info.runtime_bwd_mem = mem
def _extract_computing_power(self): def _extract_computing_power(self):
...@@ -159,12 +151,12 @@ class Solver(ABC): ...@@ -159,12 +151,12 @@ class Solver(ABC):
return NvDevicePower.RTX3080_FP16 * units return NvDevicePower.RTX3080_FP16 * units
elif device_name.__contains__("RTX 3090"): elif device_name.__contains__("RTX 3090"):
return NvDevicePower.RTX3090_FP16 * units return NvDevicePower.RTX3090_FP16 * units
elif device_name.__contains__('V100'): elif device_name.__contains__("V100"):
return NvDevicePower.V100_FP16 * units return NvDevicePower.V100_FP16 * units
elif device_name.__contains__("A100"): elif device_name.__contains__("A100"):
return NvDevicePower.A100_FP16 * units return NvDevicePower.A100_FP16 * units
else: else:
raise TypeError(f'Unknown NVIDIA GPU device name {device_name}') raise TypeError(f"Unknown NVIDIA GPU device name {device_name}")
def _profile_bandwidth(self): def _profile_bandwidth(self):
""" """
...@@ -172,9 +164,9 @@ class Solver(ABC): ...@@ -172,9 +164,9 @@ class Solver(ABC):
using data volumes ranging from 1KB to 1GB. using data volumes ranging from 1KB to 1GB.
""" """
print('profiling bandwidth ......') print("profiling bandwidth ......")
link_to_bandwidth = {} link_to_bandwidth = {}
links = ['h2d', 'd2h'] links = ["h2d", "d2h"]
for link in links: for link in links:
t_size = 1024 t_size = 1024
...@@ -182,24 +174,22 @@ class Solver(ABC): ...@@ -182,24 +174,22 @@ class Solver(ABC):
# from 1KB to 1GB # from 1KB to 1GB
for i in range(21): for i in range(21):
if link == 'h2d': if link == "h2d":
src_tensor = torch.ones( src_tensor = torch.ones(int(t_size), dtype=torch.int8, pin_memory=True)
int(t_size), dtype=torch.int8, pin_memory=True) dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, device="cuda")
dst_tensor = torch.ones( elif link == "d2h":
(int(t_size)), dtype=torch.int8, device='cuda') src_tensor = torch.ones(int(t_size), dtype=torch.int8, device="cuda")
elif link == 'd2h': dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, pin_memory=True)
src_tensor = torch.ones(
int(t_size), dtype=torch.int8, device='cuda')
dst_tensor = torch.ones(
(int(t_size)), dtype=torch.int8, pin_memory=True)
def func(): def func():
dst_tensor.copy_(src_tensor) dst_tensor.copy_(src_tensor)
size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3) size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3)
print(f'size: {t_size / 1024 ** 2:.3f} MB, ' print(
f'{src_tensor.device.type}-to-{dst_tensor.device.type} ' f"size: {t_size / 1024 ** 2:.3f} MB, "
f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s') f"{src_tensor.device.type}-to-{dst_tensor.device.type} "
f"bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s"
)
t_size *= 2 t_size *= 2
...@@ -208,10 +198,7 @@ class Solver(ABC): ...@@ -208,10 +198,7 @@ class Solver(ABC):
class SynGreedySolver(Solver): class SynGreedySolver(Solver):
def __init__(self, region_list: List[Region], memory_budget: float = -1.0) -> None:
def __init__(self,
region_list: List[Region],
memory_budget: float = -1.0) -> None:
super().__init__(region_list, memory_budget) super().__init__(region_list, memory_budget)
self.best_ts: SynTrainingSimulator = None self.best_ts: SynTrainingSimulator = None
...@@ -258,7 +245,8 @@ class SynGreedySolver(Solver): ...@@ -258,7 +245,8 @@ class SynGreedySolver(Solver):
else: else:
raise NotImplementedError( raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
)
def _call_solver_l2l(self): def _call_solver_l2l(self):
""" """
...@@ -270,7 +258,6 @@ class SynGreedySolver(Solver): ...@@ -270,7 +258,6 @@ class SynGreedySolver(Solver):
region.is_syn = True region.is_syn = True
def _try_to_offload(self, offload_region: Region): def _try_to_offload(self, offload_region: Region):
# record previous information # record previous information
orig_need_offload = offload_region.need_offload orig_need_offload = offload_region.need_offload
assert not orig_need_offload assert not orig_need_offload
...@@ -297,23 +284,17 @@ class SynGreedySolver(Solver): ...@@ -297,23 +284,17 @@ class SynGreedySolver(Solver):
ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute() ts.execute()
extra_comm_cost = 2.0 * \ extra_comm_cost = 2.0 * ts._get_communication_overhead("h2d", offload_region.param_size)
ts._get_communication_overhead('h2d', offload_region.param_size)
# the shared region needs to be moved twice # the shared region needs to be moved twice
if offload_region.r_id < offload_region.shared_rid: if offload_region.r_id < offload_region.shared_rid:
extra_comm_cost *= 2.0 extra_comm_cost *= 2.0
profit = self._compute_offload_profit( profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit return ts, profit
class AsynGreedySolver(Solver): class AsynGreedySolver(Solver):
def __init__(self, region_list: List[Region], memory_budget: float = -1.0, search_window_size: int = 3):
def __init__(self,
region_list: List[Region],
memory_budget: float = -1.0,
search_window_size: int = 3):
super().__init__(region_list, memory_budget) super().__init__(region_list, memory_budget)
self.search_window_size = search_window_size self.search_window_size = search_window_size
...@@ -331,7 +312,7 @@ class AsynGreedySolver(Solver): ...@@ -331,7 +312,7 @@ class AsynGreedySolver(Solver):
ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute() ts.execute()
self._update_state(ts) self._update_state(ts)
print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB") print("init peak memory", self.best_ts.peak_mem / 1024**2, "MB")
def _call_solver(self): def _call_solver(self):
""" """
...@@ -358,18 +339,17 @@ class AsynGreedySolver(Solver): ...@@ -358,18 +339,17 @@ class AsynGreedySolver(Solver):
best_pref_ts = None best_pref_ts = None
# search when to prefetch the region offloaded # search when to prefetch the region offloaded
for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]: for host_region in self.region_list[region.r_id + 1 : region.r_id + 1 + self.search_window_size]:
if host_region.bwd_prefetch_region is not None: if host_region.bwd_prefetch_region is not None:
continue continue
temp_ts, profit = self._try_to_offload( temp_ts, profit = self._try_to_offload(host_region, region)
host_region, region)
if self._compare_profit(profit, max_prefetch_profit): if self._compare_profit(profit, max_prefetch_profit):
region_to_region_map[region.r_id] = host_region region_to_region_map[region.r_id] = host_region
max_prefetch_profit = profit max_prefetch_profit = profit
best_pref_ts = temp_ts best_pref_ts = temp_ts
if profit[0] == float('inf'): if profit[0] == float("inf"):
break break
if self._compare_profit(max_prefetch_profit, max_offload_profit): if self._compare_profit(max_prefetch_profit, max_offload_profit):
...@@ -392,7 +372,8 @@ class AsynGreedySolver(Solver): ...@@ -392,7 +372,8 @@ class AsynGreedySolver(Solver):
else: else:
raise NotImplementedError( raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
)
region_to_region_map.clear() region_to_region_map.clear()
...@@ -452,7 +433,6 @@ class AsynGreedySolver(Solver): ...@@ -452,7 +433,6 @@ class AsynGreedySolver(Solver):
peak_mem_saving = 0 peak_mem_saving = 0
while len(self.region_to_region_map) and peak_mem_saving <= 0: while len(self.region_to_region_map) and peak_mem_saving <= 0:
max_profit = (0,) max_profit = (0,)
best_ts = None best_ts = None
undo_host_region = None undo_host_region = None
...@@ -464,8 +444,7 @@ class AsynGreedySolver(Solver): ...@@ -464,8 +444,7 @@ class AsynGreedySolver(Solver):
assert offload_region.need_offload assert offload_region.need_offload
assert not offload_region.is_syn assert not offload_region.is_syn
ts, profit = self._try_convert_to_syn_upload(host_region, ts, profit = self._try_convert_to_syn_upload(host_region, offload_region)
offload_region)
if self._compare_profit(profit, max_profit): if self._compare_profit(profit, max_profit):
undo_host_region = host_region undo_host_region = host_region
...@@ -474,7 +453,7 @@ class AsynGreedySolver(Solver): ...@@ -474,7 +453,7 @@ class AsynGreedySolver(Solver):
best_ts = ts best_ts = ts
if best_ts is None: if best_ts is None:
raise NotImplementedError('repair error!') raise NotImplementedError("repair error!")
assert not undo_offload_region.is_syn assert not undo_offload_region.is_syn
undo_offload_region.is_syn = True undo_offload_region.is_syn = True
...@@ -500,17 +479,13 @@ class AsynGreedySolver(Solver): ...@@ -500,17 +479,13 @@ class AsynGreedySolver(Solver):
ts.execute() ts.execute()
extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0) extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0)
profit = self._compute_offload_profit( profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit return ts, profit
class SolverFactory: class SolverFactory:
solvers: Dict[str, Type[Solver]] = { solvers: Dict[str, Type[Solver]] = {"syn": SynGreedySolver, "asyn": AsynGreedySolver}
'syn': SynGreedySolver,
'asyn': AsynGreedySolver
}
@staticmethod @staticmethod
def create(solver_name: str) -> Type[Solver]: def create(solver_name: str) -> Type[Solver]:
......
import bisect import bisect
from typing import List, Dict
from collections import OrderedDict
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict, List
from torch.fx.node import Node from torch.fx.node import Node
...@@ -26,10 +26,7 @@ class TrainingSimulator(ABC): ...@@ -26,10 +26,7 @@ class TrainingSimulator(ABC):
link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth. link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth.
""" """
def __init__(self, def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
region_list: List[Region],
comp_power: float,
link_to_bw: Dict[str, Dict[float, float]]) -> None:
self.region_list = region_list self.region_list = region_list
self.region_num = len(region_list) self.region_num = len(region_list)
...@@ -87,11 +84,7 @@ class TrainingSimulator(ABC): ...@@ -87,11 +84,7 @@ class TrainingSimulator(ABC):
class SynTrainingSimulator(TrainingSimulator): class SynTrainingSimulator(TrainingSimulator):
def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
def __init__(self,
region_list: List[Region],
comp_power: float,
link_to_bw: Dict[str, Dict[float, float]]) -> None:
super().__init__(region_list, comp_power, link_to_bw) super().__init__(region_list, comp_power, link_to_bw)
def execute(self): def execute(self):
...@@ -115,8 +108,7 @@ class SynTrainingSimulator(TrainingSimulator): ...@@ -115,8 +108,7 @@ class SynTrainingSimulator(TrainingSimulator):
self.runtime_mem += region.param_size self.runtime_mem += region.param_size
for node in region.nodes: for node in region.nodes:
self.runtime_mem += calculate_fwd_tmp(node) + \ self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
calculate_fwd_out(node)
self.fwd_node_mem[node] = self.runtime_mem self.fwd_node_mem[node] = self.runtime_mem
self.peak_mem = max(self.runtime_mem, self.peak_mem) self.peak_mem = max(self.runtime_mem, self.peak_mem)
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
...@@ -141,18 +133,15 @@ class SynTrainingSimulator(TrainingSimulator): ...@@ -141,18 +133,15 @@ class SynTrainingSimulator(TrainingSimulator):
self.runtime_mem += region.param_size self.runtime_mem += region.param_size
for node in region.nodes.__reversed__(): for node in region.nodes.__reversed__():
self.runtime_mem -= calculate_fwd_out(node) self.runtime_mem -= calculate_fwd_out(node)
self.runtime_mem += node.meta['bwd_mem_tmp'] + \ self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
node.meta['bwd_mem_out']
self.peak_mem = max(self.runtime_mem, self.peak_mem) self.peak_mem = max(self.runtime_mem, self.peak_mem)
# The memory savings of a node may be negative due to parameter prefetch. # The memory savings of a node may be negative due to parameter prefetch.
self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem
self.bwd_node_mem[node] = self.runtime_mem self.bwd_node_mem[node] = self.runtime_mem
self.runtime_mem -= (node.meta['bwd_mem_tmp'] + self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)
calculate_fwd_tmp(node))
# free bwd_mem_out # free bwd_mem_out
self.bwd_node_deps[node] = len(node.all_input_nodes) self.bwd_node_deps[node] = len(node.all_input_nodes)
...@@ -160,12 +149,14 @@ class SynTrainingSimulator(TrainingSimulator): ...@@ -160,12 +149,14 @@ class SynTrainingSimulator(TrainingSimulator):
if user_node in self.bwd_node_deps: if user_node in self.bwd_node_deps:
self.bwd_node_deps[user_node] -= 1 self.bwd_node_deps[user_node] -= 1
if self.bwd_node_deps[user_node] <= 0: if self.bwd_node_deps[user_node] <= 0:
self.runtime_mem -= user_node.meta['bwd_mem_out'] self.runtime_mem -= user_node.meta["bwd_mem_out"]
if self.runtime_mem < 0: if self.runtime_mem < 0:
raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " raise ValueError(
f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" f"region id: {region.r_id}, node name: {node.name}, "
f"runtime memory computed less than 0, which is miscalculated!") f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
f"runtime memory computed less than 0, which is miscalculated!"
)
# release parameter and offload gradient in region # release parameter and offload gradient in region
if region.r_id == region.shared_rid: if region.r_id == region.shared_rid:
...@@ -177,23 +168,16 @@ class SynTrainingSimulator(TrainingSimulator): ...@@ -177,23 +168,16 @@ class SynTrainingSimulator(TrainingSimulator):
class AsynTrainingSimulator(TrainingSimulator): class AsynTrainingSimulator(TrainingSimulator):
def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
def __init__(self,
region_list: List[Region],
comp_power: float,
link_to_bw: Dict[str, Dict[float, float]]) -> None:
super().__init__(region_list, comp_power, link_to_bw) super().__init__(region_list, comp_power, link_to_bw)
self.iter_end_time: int = 0 self.iter_end_time: int = 0
# the last computation execution period # the last computation execution period
self.last_comp: ExecutionPeriod = ExecutionPeriod( self.last_comp: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
start_time=0, end_time=0)
# the last parameter prefetch execution period # the last parameter prefetch execution period
self.last_h2d: ExecutionPeriod = ExecutionPeriod( self.last_h2d: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
start_time=0, end_time=0)
# the last gradient offload execution period # the last gradient offload execution period
self.last_d2h: ExecutionPeriod = ExecutionPeriod( self.last_d2h: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
start_time=0, end_time=0)
# the forward computation execution period of the region # the forward computation execution period of the region
self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict() self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the forward parameter prefetch execution period of the region # the forward parameter prefetch execution period of the region
...@@ -204,10 +188,8 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -204,10 +188,8 @@ class AsynTrainingSimulator(TrainingSimulator):
self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict() self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the gradient offload execution period of the region # the gradient offload execution period of the region
# which is divided into those that are waiting and those that have been released # which is divided into those that are waiting and those that have been released
self.bwd_reg_to_offl_waiting: OrderedDict[int, self.bwd_reg_to_offl_waiting: OrderedDict[int, ExecutionPeriod] = OrderedDict()
ExecutionPeriod] = OrderedDict() self.bwd_reg_to_offl_freed: OrderedDict[int, ExecutionPeriod] = OrderedDict()
self.bwd_reg_to_offl_freed: OrderedDict[int,
ExecutionPeriod] = OrderedDict()
# the region buffer, which records regions that are offloaded but not released # the region buffer, which records regions that are offloaded but not released
self.reg_buffer_to_free: List[int] = [] self.reg_buffer_to_free: List[int] = []
...@@ -217,10 +199,8 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -217,10 +199,8 @@ class AsynTrainingSimulator(TrainingSimulator):
# the region execution flow, # the region execution flow,
# where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU # where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU
# when the execution reaches the i-th region. # when the execution reaches the i-th region.
self.fwd_reg_flow = torch.zeros( self.fwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
(self.region_num, self.region_num)).bool() self.bwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
self.bwd_reg_flow = torch.zeros(
(self.region_num, self.region_num)).bool()
def execute(self): def execute(self):
""" """
...@@ -232,7 +212,7 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -232,7 +212,7 @@ class AsynTrainingSimulator(TrainingSimulator):
for reg in self.region_list: for reg in self.region_list:
if reg.param_size and reg.r_id < self.region_num - 1: if reg.param_size and reg.r_id < self.region_num - 1:
for nr in self.region_list[reg.r_id + 1:]: for nr in self.region_list[reg.r_id + 1 :]:
if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]): if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]):
reg.fwd_prefetch_region = nr reg.fwd_prefetch_region = nr
break break
...@@ -249,8 +229,7 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -249,8 +229,7 @@ class AsynTrainingSimulator(TrainingSimulator):
self.runtime_mem -= self.region_list[reg_id].param_size self.runtime_mem -= self.region_list[reg_id].param_size
self.bwd_reg_to_offl_waiting.clear() self.bwd_reg_to_offl_waiting.clear()
self.iter_end_time = max( self.iter_end_time = max(self.last_comp.end_time, self.last_d2h.end_time)
self.last_comp.end_time, self.last_d2h.end_time)
def _insert_h2d_exec(self, region: Region, is_fwd: bool = True): def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
""" """
...@@ -258,10 +237,8 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -258,10 +237,8 @@ class AsynTrainingSimulator(TrainingSimulator):
""" """
pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time) pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time)
pref_end_time = pref_start_time + \ pref_end_time = pref_start_time + 2.0 * self._get_communication_overhead("h2d", region.param_size)
2.0 * self._get_communication_overhead('h2d', region.param_size) pref_ep = ExecutionPeriod(start_time=pref_start_time, end_time=pref_end_time)
pref_ep = ExecutionPeriod(
start_time=pref_start_time, end_time=pref_end_time)
if is_fwd: if is_fwd:
self.fwd_reg_to_pref[region.r_id] = pref_ep self.fwd_reg_to_pref[region.r_id] = pref_ep
else: else:
...@@ -276,18 +253,16 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -276,18 +253,16 @@ class AsynTrainingSimulator(TrainingSimulator):
if is_fwd: if is_fwd:
reg_to_comp = self.fwd_reg_to_comp reg_to_comp = self.fwd_reg_to_comp
reg_to_pref = self.fwd_reg_to_pref reg_to_pref = self.fwd_reg_to_pref
flop_key = 'fwd_flop' flop_key = "fwd_flop"
else: else:
reg_to_comp = self.bwd_reg_to_comp reg_to_comp = self.bwd_reg_to_comp
reg_to_pref = self.bwd_reg_to_pref reg_to_pref = self.bwd_reg_to_pref
flop_key = 'bwd_flop' flop_key = "bwd_flop"
comp_start_time = max(self.last_comp.end_time, reg_to_pref.get( comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(region.r_id, ExecutionPeriod(0, 0)).end_time)
region.r_id, ExecutionPeriod(0, 0)).end_time) comp_end_time = comp_start_time + sum(
comp_end_time = comp_start_time + \ [self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes]
sum([self._get_computing_overhead(node.meta.get(flop_key, 0)) )
for node in region.nodes]) comp_ep = ExecutionPeriod(start_time=comp_start_time, end_time=comp_end_time)
comp_ep = ExecutionPeriod(
start_time=comp_start_time, end_time=comp_end_time)
reg_to_comp[region.r_id] = comp_ep reg_to_comp[region.r_id] = comp_ep
self.last_comp = comp_ep self.last_comp = comp_ep
...@@ -297,10 +272,8 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -297,10 +272,8 @@ class AsynTrainingSimulator(TrainingSimulator):
""" """
offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time) offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time)
offl_end_time = offl_start_time + \ offl_end_time = offl_start_time + self._get_communication_overhead("d2h", region.param_size)
self._get_communication_overhead('d2h', region.param_size) offl_ep = ExecutionPeriod(start_time=offl_start_time, end_time=offl_end_time)
offl_ep = ExecutionPeriod(
start_time=offl_start_time, end_time=offl_end_time)
self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep
self.last_d2h = offl_ep self.last_d2h = offl_ep
...@@ -332,20 +305,17 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -332,20 +305,17 @@ class AsynTrainingSimulator(TrainingSimulator):
self.fwd_reg_flow[region.r_id, region.r_id] = True self.fwd_reg_flow[region.r_id, region.r_id] = True
else: else:
self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1] self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1]
self.fwd_reg_flow[region.r_id, self.fwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
self.reg_buffer_to_free] = False
self.reg_buffer_to_free.clear() self.reg_buffer_to_free.clear()
# prefetch parameters of the next region # prefetch parameters of the next region
fwd_prefetch_region = region.fwd_prefetch_region fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]): if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
self.runtime_mem += fwd_prefetch_region.param_size self.runtime_mem += fwd_prefetch_region.param_size
self.fwd_reg_flow[region.r_id, self.fwd_reg_flow[region.r_id, fwd_prefetch_region.r_id] = True
fwd_prefetch_region.r_id] = True
for node in region.nodes: for node in region.nodes:
self.runtime_mem += calculate_fwd_tmp(node) + \ self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
calculate_fwd_out(node)
self.peak_mem = max(self.runtime_mem, self.peak_mem) self.peak_mem = max(self.runtime_mem, self.peak_mem)
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
...@@ -354,8 +324,7 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -354,8 +324,7 @@ class AsynTrainingSimulator(TrainingSimulator):
if region.need_offload: if region.need_offload:
self.runtime_mem -= region.param_size self.runtime_mem -= region.param_size
assert len( assert len(self.reg_buffer_to_free) <= 1, f"{len(self.reg_buffer_to_free)}"
self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}'
self.reg_buffer_to_free.append(region.r_id) self.reg_buffer_to_free.append(region.r_id)
def _eval_bwd_cost_per_region(self, region: Region): def _eval_bwd_cost_per_region(self, region: Region):
...@@ -398,8 +367,7 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -398,8 +367,7 @@ class AsynTrainingSimulator(TrainingSimulator):
self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1] self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1]
else: else:
self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1] self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1]
self.bwd_reg_flow[region.r_id, self.bwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
self.reg_buffer_to_free] = False
# free gradients in the buffer # free gradients in the buffer
while len(self.reg_buffer_to_free): while len(self.reg_buffer_to_free):
...@@ -415,8 +383,7 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -415,8 +383,7 @@ class AsynTrainingSimulator(TrainingSimulator):
bwd_prefetch_region = region.bwd_prefetch_region bwd_prefetch_region = region.bwd_prefetch_region
if bwd_prefetch_region: if bwd_prefetch_region:
self.runtime_mem += bwd_prefetch_region.param_size self.runtime_mem += bwd_prefetch_region.param_size
self.bwd_reg_flow[region.r_id, self.bwd_reg_flow[region.r_id, bwd_prefetch_region.r_id] = True
bwd_prefetch_region.r_id] = True
# add the gradient of the parameter # add the gradient of the parameter
if region.r_id < region.shared_rid: if region.r_id < region.shared_rid:
...@@ -426,10 +393,8 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -426,10 +393,8 @@ class AsynTrainingSimulator(TrainingSimulator):
self.runtime_mem += region.param_size self.runtime_mem += region.param_size
for node in region.nodes.__reversed__(): for node in region.nodes.__reversed__():
self.runtime_mem -= calculate_fwd_out(node) self.runtime_mem -= calculate_fwd_out(node)
self.runtime_mem += node.meta['bwd_mem_tmp'] + \ self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
node.meta['bwd_mem_out']
self.peak_mem = max(self.runtime_mem, self.peak_mem) self.peak_mem = max(self.runtime_mem, self.peak_mem)
# The memory savings of a node may be negative due to parameter prefetch. # The memory savings of a node may be negative due to parameter prefetch.
...@@ -437,8 +402,7 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -437,8 +402,7 @@ class AsynTrainingSimulator(TrainingSimulator):
self.bwd_node_mem[node] = self.runtime_mem self.bwd_node_mem[node] = self.runtime_mem
self.runtime_mem -= (node.meta['bwd_mem_tmp'] + self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)
calculate_fwd_tmp(node))
# free bwd_mem_out # free bwd_mem_out
self.bwd_node_deps[node] = len(node.all_input_nodes) self.bwd_node_deps[node] = len(node.all_input_nodes)
...@@ -446,12 +410,14 @@ class AsynTrainingSimulator(TrainingSimulator): ...@@ -446,12 +410,14 @@ class AsynTrainingSimulator(TrainingSimulator):
if user_node in self.bwd_node_deps: if user_node in self.bwd_node_deps:
self.bwd_node_deps[user_node] -= 1 self.bwd_node_deps[user_node] -= 1
if self.bwd_node_deps[user_node] <= 0: if self.bwd_node_deps[user_node] <= 0:
self.runtime_mem -= user_node.meta['bwd_mem_out'] self.runtime_mem -= user_node.meta["bwd_mem_out"]
if self.runtime_mem < 0: if self.runtime_mem < 0:
raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " raise ValueError(
f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" f"region id: {region.r_id}, node name: {node.name}, "
f"runtime memory computed less than 0, which is miscalculated!") f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
f"runtime memory computed less than 0, which is miscalculated!"
)
# release parameters of the region # release parameters of the region
if requires_release_p_in_bwd(self.region_list[region.shared_rid]): if requires_release_p_in_bwd(self.region_list[region.shared_rid]):
......
...@@ -35,7 +35,6 @@ class NvDevicePower: ...@@ -35,7 +35,6 @@ class NvDevicePower:
class GlobalRuntimeInfo(metaclass=SingletonMeta): class GlobalRuntimeInfo(metaclass=SingletonMeta):
def __init__(self): def __init__(self):
self.h2d_stream = torch.cuda.Stream() self.h2d_stream = torch.cuda.Stream()
self.d2h_stream = torch.cuda.Stream() self.d2h_stream = torch.cuda.Stream()
...@@ -50,21 +49,18 @@ def compute_act_peak_mem(region_list: List[Region]) -> float: ...@@ -50,21 +49,18 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
# forward # forward
for region in region_list: for region in region_list:
for node in region.nodes: for node in region.nodes:
runtime_mem = runtime_mem + \ runtime_mem = runtime_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
calculate_fwd_tmp(node) + calculate_fwd_out(node)
act_peak_mem = max(runtime_mem, act_peak_mem) act_peak_mem = max(runtime_mem, act_peak_mem)
# backward # backward
bwd_deps = {} bwd_deps = {}
for region in region_list.__reversed__(): for region in region_list.__reversed__():
for node in region.nodes.__reversed__(): for node in region.nodes.__reversed__():
runtime_mem -= calculate_fwd_out(node) runtime_mem -= calculate_fwd_out(node)
runtime_mem = runtime_mem + \ runtime_mem = runtime_mem + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out']
act_peak_mem = max(runtime_mem, act_peak_mem) act_peak_mem = max(runtime_mem, act_peak_mem)
runtime_mem = runtime_mem - \ runtime_mem = runtime_mem - node.meta["bwd_mem_tmp"] - calculate_fwd_tmp(node)
node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node)
# free bwd_mem_out # free bwd_mem_out
bwd_deps[node] = len(node.all_input_nodes) bwd_deps[node] = len(node.all_input_nodes)
...@@ -72,7 +68,7 @@ def compute_act_peak_mem(region_list: List[Region]) -> float: ...@@ -72,7 +68,7 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
if user_node in bwd_deps: if user_node in bwd_deps:
bwd_deps[user_node] -= 1 bwd_deps[user_node] -= 1
if bwd_deps[user_node] <= 0: if bwd_deps[user_node] <= 0:
runtime_mem -= user_node.meta['bwd_mem_out'] runtime_mem -= user_node.meta["bwd_mem_out"]
return act_peak_mem return act_peak_mem
...@@ -86,13 +82,15 @@ def compute_total_param_mem(region_list: List[Region]) -> float: ...@@ -86,13 +82,15 @@ def compute_total_param_mem(region_list: List[Region]) -> float:
def requires_upload_p_in_fwd(shared_reg: Region):
    """
    Decide, from a region's ids and offload flag, whether its parameters need
    an upload in the forward pass (purpose inferred from the function name).

    Args:
        shared_reg (Region): the region to inspect; only its ``r_id``,
            ``shared_rid`` and ``need_offload`` attributes are read.

    Returns:
        ``True`` when ``r_id >= shared_rid``; otherwise the value of
        ``need_offload`` (the ``r_id < shared_rid`` case).
    """
    # a region whose id is not smaller than its shared id always qualifies
    if shared_reg.r_id >= shared_reg.shared_rid:
        return True
    # otherwise it only qualifies when it is marked for offload
    return shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
def requires_release_p_in_bwd(shared_reg: Region):
    """
    Decide, from a region's ids and offload flag, whether its parameters need
    to be released in the backward pass (purpose inferred from the function name).

    Args:
        shared_reg (Region): the region to inspect; only its ``r_id``,
            ``shared_rid`` and ``need_offload`` attributes are read.

    Returns:
        ``True`` when ``r_id >= shared_rid``; otherwise the value of
        ``need_offload`` (the ``r_id < shared_rid`` case).
    """
    # a region whose id is not smaller than its shared id always qualifies
    if shared_reg.r_id >= shared_reg.shared_rid:
        return True
    # otherwise it only qualifies when it is marked for offload
    return shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
def requires_offload_g_in_bwd(region: Region): def requires_offload_g_in_bwd(region: Region):
......
...@@ -14,18 +14,20 @@ from colossalai.tensor.sharding_spec import ShardingSpec ...@@ -14,18 +14,20 @@ from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager() shape_consistency_manager = ShapeConsistencyManager()
def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, def _construct_shard_meta_info(
target_sharding_spec: ShardingSpec) -> ShardMetaInfo: node: Node, origin_sharding_spec: ShardingSpec, target_sharding_spec: ShardingSpec
) -> ShardMetaInfo:
# get comm_action_sequence and total_cost from shape_consistency_manager # get comm_action_sequence and total_cost from shape_consistency_manager
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
origin_sharding_spec, target_sharding_spec) origin_sharding_spec, target_sharding_spec
)
meta_info = ShardMetaInfo() meta_info = ShardMetaInfo()
# NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
# get mem cost for ShardMetaInfo # get mem cost for ShardMetaInfo
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
# extract user that has _meta_data and extract element length # extract user that has _meta_data and extract element length
input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data')) input_node = next(n for n in node._input_nodes if hasattr(n, "_meta_data"))
element_length = input_node._meta_data.element_size() element_length = input_node._meta_data.element_size()
mem_cost.fwd.activation *= element_length mem_cost.fwd.activation *= element_length
...@@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, ...@@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
meta_info.memory_cost = mem_cost meta_info.memory_cost = mem_cost
# get computation cost for ShardMetaInfo # get computation cost for ShardMetaInfo
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, meta_info.compute_cost = TrainCycleItem(
total_cost['backward'] * element_length, total_cost["forward"] * element_length,
total_cost['total'] * element_length) total_cost["backward"] * element_length,
total_cost["total"] * element_length,
)
# get tensor shape for ShardMetaInfo # get tensor shape for ShardMetaInfo
origin_sharding_spec: ShardingSpec origin_sharding_spec: ShardingSpec
...@@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, ...@@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
input_shape = origin_sharding_spec.get_sharded_shape_per_device() input_shape = origin_sharding_spec.get_sharded_shape_per_device()
output_shape = target_sharding_spec.get_sharded_shape_per_device() output_shape = target_sharding_spec.get_sharded_shape_per_device()
meta_info.fwd_in = [torch.rand(input_shape, device='meta')] meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
meta_info.fwd_buffer = [] meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')] meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
return meta_info return meta_info
...@@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) - ...@@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
# extract node index and user node index # extract node index and user node index
args = node.args args = node.args
node_index, user_node_index = args[3], args[4] node_index, user_node_index = args[3], args[4]
origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][ origin_sharding_spec, target_sharding_spec = (
user_node_index] origin_spec_dict[node_index],
sharding_spec_dict[node_index][user_node_index],
)
return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
...@@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S ...@@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S
# this case is for all_reduce, there will be no memory cost # this case is for all_reduce, there will be no memory cost
meta_info = ShardMetaInfo() meta_info = ShardMetaInfo()
meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost) meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)
output_node = next(n for n in node.users if hasattr(n, '_meta_data')) output_node = next(n for n in node.users if hasattr(n, "_meta_data"))
element_length = output_node._meta_data.element_size() element_length = output_node._meta_data.element_size()
total_cost = comm_action.comm_spec.get_comm_cost() total_cost = comm_action.comm_spec.get_comm_cost()
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, meta_info.compute_cost = TrainCycleItem(
total_cost['backward'] * element_length, total_cost["forward"] * element_length,
total_cost['total'] * element_length) total_cost["backward"] * element_length,
total_cost["total"] * element_length,
)
input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device() input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
meta_info.fwd_in = [torch.rand(input_shape, device='meta')] meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
meta_info.fwd_buffer = [] meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')] meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
else: else:
# this case will be handled by shape consistency manager # this case will be handled by shape consistency manager
origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[ origin_sharding_spec, target_sharding_spec = (
'tgt_spec'] comm_action.comm_spec["src_spec"],
comm_action.comm_spec["tgt_spec"],
)
meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
return meta_info return meta_info
def comm_metainfo_pass(
    gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, comm_actions_dict: Dict
) -> GraphModule:
    """
    Attach meta information to the communication nodes of the graph.

    Every node whose target is ``runtime_apply`` or ``runtime_comm_spec_apply``
    receives a ``best_strategy_info`` attribute; all other nodes are left untouched.

    Args:
        gm (GraphModule): the graph module whose nodes are visited.
        sharding_spec_dict (Dict): forwarded to ``_runtime_apply_meta_info`` as the
            target sharding spec lookup.
        origin_spec_dict (Dict): forwarded to ``_runtime_apply_meta_info`` as the
            origin sharding spec lookup.
        comm_actions_dict (Dict): forwarded to ``_runtime_comm_spec_apply_meta_info``.

    Returns:
        GraphModule: the same ``gm`` object, with its nodes annotated in place.
    """
    for node in gm.graph.nodes:
        if node.target == runtime_apply:
            # shape-consistency conversion inserted between a node and its user
            node.best_strategy_info = _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)
        elif node.target == runtime_comm_spec_apply:
            # communication action attached to a strategy
            node.best_strategy_info = _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)
    return gm
...@@ -21,16 +21,15 @@ def _normalize_tuple(x): ...@@ -21,16 +21,15 @@ def _normalize_tuple(x):
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
class MetaInfoProp: class MetaInfoProp:
def __init__(self, module: GraphModule) -> None: def __init__(self, module: GraphModule) -> None:
self.module = module self.module = module
self.func_dict = { self.func_dict = {
'placeholder': self.placeholder_handler, "placeholder": self.placeholder_handler,
'get_attr': self.get_attr_handler, "get_attr": self.get_attr_handler,
'output': self.output_handler, "output": self.output_handler,
'call_function': self.node_handler, "call_function": self.node_handler,
'call_module': self.node_handler, "call_module": self.node_handler,
'call_method': self.node_handler, "call_method": self.node_handler,
} }
def _set_data_ptr(self, x): def _set_data_ptr(self, x):
...@@ -46,7 +45,7 @@ class MetaInfoProp: ...@@ -46,7 +45,7 @@ class MetaInfoProp:
""" """
Check if the node is inplace operation. Check if the node is inplace operation.
""" """
if node.op == 'call_module': if node.op == "call_module":
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
elif node.op == "call_function": elif node.op == "call_function":
return node.target in OUTPUT_SAVED_OPS return node.target in OUTPUT_SAVED_OPS
...@@ -66,7 +65,7 @@ class MetaInfoProp: ...@@ -66,7 +65,7 @@ class MetaInfoProp:
Handle the placeholder node. Handle the placeholder node.
""" """
graph_info = GraphInfo() graph_info = GraphInfo()
out = _normalize_tuple(getattr(node, '_meta_data', None)) out = _normalize_tuple(getattr(node, "_meta_data", None))
graph_info.fwd_out = list(out) if out[0] is not None else [] graph_info.fwd_out = list(out) if out[0] is not None else []
node.meta = {**asdict(graph_info)} node.meta = {**asdict(graph_info)}
...@@ -96,7 +95,7 @@ class MetaInfoProp: ...@@ -96,7 +95,7 @@ class MetaInfoProp:
""" """
Handle other kind of nodes Handle other kind of nodes
""" """
assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}" assert hasattr(node, "best_strategy_info"), f"Cannot find best_strategy_info in node {node}, {node.op}"
graph_info = GraphInfo() graph_info = GraphInfo()
meta_info = node.best_strategy_info meta_info = node.best_strategy_info
meta_info: ShardMetaInfo meta_info: ShardMetaInfo
...@@ -126,7 +125,8 @@ class MetaInfoProp: ...@@ -126,7 +125,8 @@ class MetaInfoProp:
for tensor in par.meta.get("fwd_out", []): for tensor in par.meta.get("fwd_out", []):
tensor: torch.Tensor tensor: torch.Tensor
target_input_tensor = next( target_input_tensor = next(
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None) (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None
)
if target_input_tensor is not None: if target_input_tensor is not None:
target_input_tensor.data_ptr = tensor.data_ptr target_input_tensor.data_ptr = tensor.data_ptr
...@@ -148,7 +148,7 @@ class MetaInfoProp: ...@@ -148,7 +148,7 @@ class MetaInfoProp:
graph_info.fwd_tmp = buffer_tensors graph_info.fwd_tmp = buffer_tensors
graph_info.fwd_out = output_tensors graph_info.fwd_out = output_tensors
# fetch other memory informations # fetch other memory information
memory_cost = meta_info.memory_cost memory_cost = meta_info.memory_cost
graph_info.fwd_mem_tmp = memory_cost.fwd.temp graph_info.fwd_mem_tmp = memory_cost.fwd.temp
graph_info.fwd_mem_out = memory_cost.fwd.activation graph_info.fwd_mem_out = memory_cost.fwd.activation
......
from copy import deepcopy
from typing import Dict, List from typing import Dict, List
import torch import torch
from torch.fx.node import Node from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
CommAction,
CommType,
OperationData,
OperationDataType,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import CommSpec from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
...@@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i ...@@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec) return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
def runtime_apply_for_iterable_object(
    node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int
):
    """
    This method will be invoked during runtime to do the shape consistency, which makes sure
    the activations in type of tuple or list are converted into the form expected by the user node.

    Args:
        node: the iterable activation; it is indexed positionally and its type is reused
            to build the result.
        origin_dict (Dict): maps ``node_index`` to the sequence of origin sharding specs.
        input_dict (Dict): maps ``node_index`` and then ``user_node_index`` to the sequence
            of target sharding specs.
        node_index (int): key of the producing node in both dictionaries.
        user_node_index (int): key of the consuming node inside ``input_dict[node_index]``.

    Returns:
        An object of the same type as ``node`` holding the converted elements. Elements are
        paired with ``zip``, so only as many as the shorter spec sequence are converted.
    """
    origin_specs = origin_dict[node_index]
    target_specs = input_dict[node_index][user_node_index]
    converted = [
        shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_spec, target_spec)
        for index, (origin_spec, target_spec) in enumerate(zip(origin_specs, target_specs))
    ]
    # rebuild the same container type (tuple or list) the activation arrived in
    return type(node)(converted)
...@@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_ ...@@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_
if isinstance(comm_action.comm_spec, CommSpec): if isinstance(comm_action.comm_spec, CommSpec):
rst = comm_action.comm_spec.covert_spec_to_action(tensor) rst = comm_action.comm_spec.covert_spec_to_action(tensor)
else: else:
origin_sharding_spec = comm_action.comm_spec['src_spec'] origin_sharding_spec = comm_action.comm_spec["src_spec"]
tgt_sharding_spec = comm_action.comm_spec['tgt_spec'] tgt_sharding_spec = comm_action.comm_spec["tgt_spec"]
rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec) rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
return rst return rst
...@@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]): ...@@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]):
node_to_index_dict = {} node_to_index_dict = {}
index = 0 index = 0
for node in nodes: for node in nodes:
if node.target == 'sharding_spec_convert_dict': if node.target == "sharding_spec_convert_dict":
input_dict_node = node input_dict_node = node
continue continue
if node.target == 'origin_node_sharding_spec_dict': if node.target == "origin_node_sharding_spec_dict":
origin_dict_node = node origin_dict_node = node
continue continue
if node.target == 'comm_actions_dict': if node.target == "comm_actions_dict":
comm_actions_dict_node = node comm_actions_dict_node = node
continue continue
if not hasattr(node, 'best_strategy'): if not hasattr(node, "best_strategy"):
continue continue
node_to_index_dict[node] = index node_to_index_dict[node] = index
index += 1 index += 1
...@@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): ...@@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes) input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes: for node in nodes:
if not hasattr(node, 'best_strategy') or node.op == 'output': if not hasattr(node, "best_strategy") or node.op == "output":
continue continue
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes): for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
if isinstance(node.sharding_spec, (list, tuple)): if isinstance(node.sharding_spec, (list, tuple)):
assert isinstance( assert isinstance(
node.target_sharding_specs, node.target_sharding_specs, (list, tuple)
(list, ), "target sharding specs should be tuple or list when node.sharding_spec is tuple or list"
tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
total_difference = 0 total_difference = 0
for sharding_spec, target_sharding_spec in zip(node.sharding_spec, for sharding_spec, target_sharding_spec in zip(
node.target_sharding_specs[user_node_index]): node.sharding_spec, node.target_sharding_specs[user_node_index]
):
total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec) total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
if total_difference == 0: if total_difference == 0:
continue continue
with mod_graph.inserting_before(user_node): with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function', shape_consistency_node = mod_graph.create_node(
runtime_apply_for_iterable_object, "call_function",
args=(node, origin_dict_node, input_dict_node, runtime_apply_for_iterable_object,
node_to_index_dict[node], user_node_index)) args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
)
else: else:
assert isinstance(node.sharding_spec, assert isinstance(
ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.' node.sharding_spec, ShardingSpec
), "node.sharding_spec should be type of ShardingSpec, tuple or list."
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0: if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
continue continue
with mod_graph.inserting_before(user_node): with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function', shape_consistency_node = mod_graph.create_node(
runtime_apply, "call_function",
args=(node, origin_dict_node, input_dict_node, runtime_apply,
node_to_index_dict[node], user_node_index)) args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
if hasattr(user_node.meta['info'], 'activation_checkpoint'): )
MetaInfo(shape_consistency_node, if hasattr(user_node.meta["info"], "activation_checkpoint"):
mod_dir=user_node.meta['info'].mod_dir, MetaInfo(
activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint)) shape_consistency_node,
mod_dir=user_node.meta["info"].mod_dir,
activation_checkpoint=tuple(user_node.meta["info"].activation_checkpoint),
)
new_args = list(user_node.args) new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs) new_kwargs = dict(user_node.kwargs)
# the origin node may be a positional argument or key word argument of user node # the origin node may be a positional argument or key word argument of user node
...@@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): ...@@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
_, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes) _, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes: for node in nodes:
if not hasattr(node, 'best_strategy') or node.op == 'output': if not hasattr(node, "best_strategy") or node.op == "output":
continue continue
comm_actions = node.best_strategy.communication_actions comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items(): for op_data, comm_action in comm_actions.items():
if comm_action.comm_type == CommType.HOOK: if comm_action.comm_type == CommType.HOOK:
continue continue
if comm_action.comm_type == CommType.BEFORE: if comm_action.comm_type == CommType.BEFORE:
...@@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): ...@@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
else: else:
comm_object = node.args[comm_action.arg_index] comm_object = node.args[comm_action.arg_index]
with mod_graph.inserting_before(node): with mod_graph.inserting_before(node):
comm_spec_apply_node = mod_graph.create_node('call_function', comm_spec_apply_node = mod_graph.create_node(
runtime_comm_spec_apply, "call_function",
args=(comm_object, comm_actions_dict_node, runtime_comm_spec_apply,
node_to_index_dict[node], op_data.name)) args=(comm_object, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
)
# the origin node may be a positional argument or key word argument of user node # the origin node may be a positional argument or key word argument of user node
if comm_action.key_for_kwarg is not None: if comm_action.key_for_kwarg is not None:
# substitute the origin node with comm_spec_apply_node # substitute the origin node with comm_spec_apply_node
...@@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): ...@@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
elif comm_action.comm_type == CommType.AFTER: elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node): with mod_graph.inserting_after(node):
comm_spec_apply_node = mod_graph.create_node('call_function', comm_spec_apply_node = mod_graph.create_node(
runtime_comm_spec_apply, "call_function",
args=(node, comm_actions_dict_node, runtime_comm_spec_apply,
node_to_index_dict[node], op_data.name)) args=(node, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
)
user_list = list(node.users.keys()) user_list = list(node.users.keys())
for user in user_list: for user in user_list:
if user == comm_spec_apply_node: if user == comm_spec_apply_node:
...@@ -211,15 +212,17 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): ...@@ -211,15 +212,17 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node # substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs user.kwargs = new_kwargs
if hasattr(node.meta['info'], 'activation_checkpoint'): if hasattr(node.meta["info"], "activation_checkpoint"):
MetaInfo(comm_spec_apply_node, MetaInfo(
mod_dir=node.meta['info'].mod_dir, comm_spec_apply_node,
activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) mod_dir=node.meta["info"].mod_dir,
activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
)
return gm return gm
def _act_annotataion_pass(gm: torch.fx.GraphModule): def _act_annotation_pass(gm: torch.fx.GraphModule):
""" """
This pass is used to add the act annotation to the new inserted nodes. This pass is used to add the act annotation to the new inserted nodes.
""" """
...@@ -227,21 +230,21 @@ def _act_annotataion_pass(gm: torch.fx.GraphModule): ...@@ -227,21 +230,21 @@ def _act_annotataion_pass(gm: torch.fx.GraphModule):
nodes = tuple(mod_graph.nodes) nodes = tuple(mod_graph.nodes)
for node in nodes: for node in nodes:
if not hasattr(node.meta, 'activation_checkpoint'): if not hasattr(node.meta, "activation_checkpoint"):
from .runtime_preparation_pass import size_processing pass
user_act_annotation = -1 user_act_annotation = -1
input_act_annotation = -1 input_act_annotation = -1
for user_node in node.users.keys(): for user_node in node.users.keys():
if 'activation_checkpoint' in user_node.meta: if "activation_checkpoint" in user_node.meta:
user_act_annotation = user_node.meta['activation_checkpoint'] user_act_annotation = user_node.meta["activation_checkpoint"]
break break
for input_node in node._input_nodes.keys(): for input_node in node._input_nodes.keys():
if 'activation_checkpoint' in input_node.meta: if "activation_checkpoint" in input_node.meta:
input_act_annotation = input_node.meta['activation_checkpoint'] input_act_annotation = input_node.meta["activation_checkpoint"]
break break
if user_act_annotation == input_act_annotation and user_act_annotation != -1: if user_act_annotation == input_act_annotation and user_act_annotation != -1:
node.meta['activation_checkpoint'] = user_act_annotation node.meta["activation_checkpoint"] = user_act_annotation
return gm return gm
......
import operator import operator
from copy import deepcopy
from typing import Dict, List, Union from typing import Dict, List, Union
import torch import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
CommAction,
CommType,
OperationDataType,
ShardingStrategy,
)
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import _all_reduce from colossalai.tensor.comm_spec import _all_reduce
...@@ -25,11 +18,13 @@ from .constants import SHAPE_ARGUMENT_OPS ...@@ -25,11 +18,13 @@ from .constants import SHAPE_ARGUMENT_OPS
shape_consistency_manager = ShapeConsistencyManager() shape_consistency_manager = ShapeConsistencyManager()
def size_processing(size: Union[int, torch.Size], def size_processing(
dim_partition_dict: Dict[int, List[int]], size: Union[int, torch.Size],
device_mesh_info: Dict[int, int], dim_partition_dict: Dict[int, List[int]],
target_dim: int = None, device_mesh_info: Dict[int, int],
node_name: str = None): target_dim: int = None,
node_name: str = None,
):
""" """
This method will be invoked during runtime to convert size node value depending on distributed information. This method will be invoked during runtime to convert size node value depending on distributed information.
""" """
...@@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size], ...@@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size],
return size return size
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], def solution_annotation_pass(
strategies_constructor: StrategiesConstructor): gm: torch.fx.GraphModule, solution: List[int], strategies_constructor: StrategiesConstructor
):
""" """
This method is used to stick the solution strategy to the nodes and add the information This method is used to stick the solution strategy to the nodes and add the information
required in runtime into graph as placeholder nodes. required in runtime into graph as placeholder nodes.
...@@ -70,14 +66,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], ...@@ -70,14 +66,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector strategies_vector = node.strategies_vector
# stick the solution strategy to the corresponding node # stick the solution strategy to the corresponding node
setattr(node, 'best_strategy', strategies_vector[strategy_index]) setattr(node, "best_strategy", strategies_vector[strategy_index])
setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))) setattr(node, "sharding_spec", strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name( origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
str(node)) str(node)
)
# attach the corresponding metainfo if node has the attribute `strategies_info` # attach the corresponding metainfo if node has the attribute `strategies_info`
if hasattr(node, 'strategies_info'): if hasattr(node, "strategies_info"):
setattr(node, 'best_strategy_info', node.strategies_info[strategy_index]) setattr(node, "best_strategy_info", node.strategies_info[strategy_index])
# the dict to get input sharding specs of user node # the dict to get input sharding specs of user node
sharding_spec_convert_dict = {} sharding_spec_convert_dict = {}
...@@ -92,15 +89,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], ...@@ -92,15 +89,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name)) target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
target_sharding_specs.append(target_sharding_spec) target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs sharding_spec_convert_dict[index] = target_sharding_specs
setattr(node, 'target_sharding_specs', target_sharding_specs) setattr(node, "target_sharding_specs", target_sharding_specs)
# the get_attr node strategy is kind of pending strategy, which means we will change it # the get_attr node strategy is kind of pending strategy, which means we will change it
# to the same strategy of the user node. # to the same strategy of the user node.
if node.op == 'get_attr': if node.op == "get_attr":
assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.' assert len(target_sharding_specs) == 1, f"sharing weight is not supported in current version."
target_node = node.strategies_vector.successor_nodes[0] target_node = node.strategies_vector.successor_nodes[0]
node_name = str(node) node_name = str(node)
if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP: if target_node.op == "call_function" and target_node.target in RESHAPE_FUNC_OP:
node_name = str(target_node) node_name = str(target_node)
target_node = target_node.strategies_vector.successor_nodes[0] target_node = target_node.strategies_vector.successor_nodes[0]
user_strategy = target_node.best_strategy user_strategy = target_node.best_strategy
...@@ -122,11 +119,11 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], ...@@ -122,11 +119,11 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
# add above dicts into graph # add above dicts into graph
for node in nodes: for node in nodes:
if node.op != 'placeholder': if node.op != "placeholder":
with mod_graph.inserting_before(node): with mod_graph.inserting_before(node):
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict') comm_actions_dict_node = mod_graph.create_node("placeholder", target="comm_actions_dict")
break break
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
...@@ -144,11 +141,11 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh ...@@ -144,11 +141,11 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
# DeviceMesh information instructs the scaling of the size value # DeviceMesh information instructs the scaling of the size value
device_mesh_info = {} device_mesh_info = {}
for dim, dim_size in enumerate(device_mesh.mesh_shape): for dim, dim_size in enumerate(device_mesh.shape):
device_mesh_info[dim] = dim_size device_mesh_info[dim] = dim_size
def _extract_target_dim(node): def _extract_target_dim(node):
''' """
A helper function to extract the target dimension from size node. A helper function to extract the target dimension from size node.
There are two usages of torch.Tensor.size: There are two usages of torch.Tensor.size:
1. tensor.size() 1. tensor.size()
...@@ -156,7 +153,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh ...@@ -156,7 +153,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
If a target_dim is assigned, then the output will be in type of int, instead of torch.Size. If a target_dim is assigned, then the output will be in type of int, instead of torch.Size.
Otherwise, the output will be in type of torch.Size and this function will return None. Otherwise, the output will be in type of torch.Size and this function will return None.
''' """
target_dim = None target_dim = None
if len(node.args) > 1: if len(node.args) > 1:
target_dim = node.args[1] target_dim = node.args[1]
...@@ -165,19 +162,21 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh ...@@ -165,19 +162,21 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
return target_dim return target_dim
def _post_processing(node, size_processing_node): def _post_processing(node, size_processing_node):
''' """
This function is used to process the dependency between the size node and its users after This function is used to process the dependency between the size node and its users after
inserting the size_process_node. inserting the size_process_node.
''' """
# store original node and processing node pair in node_pairs dictioanry # store original node and processing node pair in node_pairs dictionary
# It will be used to replace the original node with processing node in slice object # It will be used to replace the original node with processing node in slice object
node_pairs[node] = size_processing_node node_pairs[node] = size_processing_node
size_processing_node._meta_data = node._meta_data size_processing_node._meta_data = node._meta_data
if hasattr(node.meta['info'], 'activation_checkpoint'): if hasattr(node.meta["info"], "activation_checkpoint"):
MetaInfo(size_processing_node, MetaInfo(
mod_dir=node.meta['info'].mod_dir, size_processing_node,
activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) mod_dir=node.meta["info"].mod_dir,
activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
)
user_list = list(node.users.keys()) user_list = list(node.users.keys())
for user in user_list: for user in user_list:
...@@ -196,10 +195,10 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh ...@@ -196,10 +195,10 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
user.kwargs = new_kwargs user.kwargs = new_kwargs
def _update_slice_object_args(slice_object): def _update_slice_object_args(slice_object):
''' """
This function is used to update the slice object argument list. This function is used to update the slice object argument list.
If the slice object contains the Node argument, then the size node will be replaced with If the slice object contains the Node argument, then the size node will be replaced with
''' """
if isinstance(slice_object, slice): if isinstance(slice_object, slice):
start = slice_object.start start = slice_object.start
stop = slice_object.stop stop = slice_object.stop
...@@ -220,8 +219,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh ...@@ -220,8 +219,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}") raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")
for node in nodes: for node in nodes:
if node.op == "call_method" and node.target == "size":
if node.op == 'call_method' and node.target == 'size':
# extract useful information from size node # extract useful information from size node
# dim_partition_dict will instruct the size value on which # dim_partition_dict will instruct the size value on which
# dimension should be enlarged. # dimension should be enlarged.
...@@ -232,14 +230,14 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh ...@@ -232,14 +230,14 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
# insert size_processing node # insert size_processing node
with mod_graph.inserting_after(node): with mod_graph.inserting_after(node):
size_processing_node = mod_graph.create_node('call_function', size_processing_node = mod_graph.create_node(
size_processing, "call_function",
args=(node, dim_partition_dict, device_mesh_info, size_processing,
target_dim, node.name)) args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name),
)
_post_processing(node, size_processing_node) _post_processing(node, size_processing_node)
if node.op == 'call_function' and node.target == operator.getitem: if node.op == "call_function" and node.target == operator.getitem:
getitem_index = node.args[1] getitem_index = node.args[1]
# slice object is quite special in torch.fx graph, # slice object is quite special in torch.fx graph,
# On one side, we treat slice object same as type of int, # On one side, we treat slice object same as type of int,
...@@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) ...@@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
nodes = tuple(mod_graph.nodes) nodes = tuple(mod_graph.nodes)
def _extract_info_from_sharding_spec(sharding_spec): def _extract_info_from_sharding_spec(sharding_spec):
''' """
This function is used to extract the dim_partition_dict and device_mesh from This function is used to extract the dim_partition_dict and device_mesh from
sharding spec instance or a list of sharding spec. sharding spec instance or a list of sharding spec.
''' """
if isinstance(sharding_spec, ShardingSpec): if isinstance(sharding_spec, ShardingSpec):
dim_partition_dict = sharding_spec.dim_partition_dict dim_partition_dict = sharding_spec.dim_partition_dict
device_mesh = sharding_spec.device_mesh device_mesh = sharding_spec.device_mesh
return dim_partition_dict, device_mesh return dim_partition_dict, device_mesh
if sharding_spec is None: if sharding_spec is None:
return None, None return None, None
assert isinstance(sharding_spec, assert isinstance(
(tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None' sharding_spec, (tuple, list)
), "sharding_spec should be type of ShardingSpec, tuple, list or None"
device_mesh = sharding_spec[0].device_mesh device_mesh = sharding_spec[0].device_mesh
dim_partition_dict = [] dim_partition_dict = []
...@@ -322,8 +321,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) ...@@ -322,8 +321,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
else: else:
new_args.append(arg) new_args.append(arg)
else: else:
assert isinstance(arg, assert isinstance(
(int, tuple, list)), 'The argument in view node should be either type of Node or int.' arg, (int, tuple, list)
), "The argument in view node should be either type of Node or int."
if isinstance(arg, (tuple, list)): if isinstance(arg, (tuple, list)):
new_args.extend(arg) new_args.extend(arg)
else: else:
...@@ -332,7 +332,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) ...@@ -332,7 +332,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node): def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
new_args = _process_node_arguments(node) new_args = _process_node_arguments(node)
if node.op == 'call_method': if node.op == "call_method":
args_to_process = list(new_args[1:]) args_to_process = list(new_args[1:])
else: else:
args_to_process = list(new_args) args_to_process = list(new_args)
...@@ -350,7 +350,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) ...@@ -350,7 +350,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
args_to_process = tuple(args_to_process) args_to_process = tuple(args_to_process)
if node.op == 'call_method': if node.op == "call_method":
new_args = (new_args[0],) + args_to_process new_args = (new_args[0],) + args_to_process
else: else:
new_args = args_to_process new_args = args_to_process
...@@ -358,9 +358,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) ...@@ -358,9 +358,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
node.args = new_args node.args = new_args
def _filter_node_with_shape_args(node): def _filter_node_with_shape_args(node):
if node.op == 'call_method': if node.op == "call_method":
target = getattr(node.args[0]._meta_data.__class__, node.target) target = getattr(node.args[0]._meta_data.__class__, node.target)
elif node.op == 'call_function': elif node.op == "call_function":
target = node.target target = node.target
else: else:
target = None target = None
...@@ -371,7 +371,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) ...@@ -371,7 +371,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
for node in nodes: for node in nodes:
# skip the placeholder node added in _solution_annotation pass # skip the placeholder node added in _solution_annotation pass
if not hasattr(node, 'sharding_spec'): if not hasattr(node, "sharding_spec"):
continue continue
output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec) output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
...@@ -388,19 +388,25 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes ...@@ -388,19 +388,25 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
""" """
mod_graph = gm.graph mod_graph = gm.graph
nodes = tuple(mod_graph.nodes) nodes = tuple(mod_graph.nodes)
# This stream is created for overlaping the communication and computation. # This stream is created for overlapping the communication and computation.
reduction_stream = torch.cuda.Stream() reduction_stream = torch.cuda.Stream()
def _add_hook_for_grad_communication(node, param, name=None): def _add_hook_for_grad_communication(node, param, name=None):
comm_actions = node.best_strategy.communication_actions comm_actions = node.best_strategy.communication_actions
def _filter_param_to_hook(node, op_data, comm_action, name): def _filter_param_to_hook(node, op_data, comm_action, name):
if (
if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK: node.op == "call_module"
and op_data.type == OperationDataType.PARAM
and op_data.name == name
and comm_action.comm_type == CommType.HOOK
):
return True return True
if node.op == 'get_attr' and isinstance( if (
node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: node.op == "get_attr"
and isinstance(node._meta_data, torch.nn.parameter.Parameter)
and comm_action.comm_type == CommType.HOOK
):
return True return True
return False return False
...@@ -410,7 +416,6 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes ...@@ -410,7 +416,6 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
if _filter_param_to_hook(node, operation_data, comm_action, name=name): if _filter_param_to_hook(node, operation_data, comm_action, name=name):
def wrapper(param, comm_spec, stream, overlap): def wrapper(param, comm_spec, stream, overlap):
def hook_fn(grad): def hook_fn(grad):
if overlap: if overlap:
with torch.cuda.stream(stream): with torch.cuda.stream(stream):
...@@ -426,22 +431,26 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes ...@@ -426,22 +431,26 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
# apply the sharding spec of parameters # apply the sharding spec of parameters
if target_sharding_spec.dim_partition_dict != {}: if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {}) origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec) setattr(param, "sharding_spec", origin_sharding_spec)
# TODO: build a ColoParameter class to manager the distributed parameters # TODO: build a ColoParameter class to manager the distributed parameters
# we could use .data here, because all the operations just happen before the real training # we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph. # loop, so we don't need to track these operations in the autograd graph.
param = torch.nn.Parameter( param = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, shape_consistency_manager.apply_for_autoparallel_runtime(
target_sharding_spec).detach().clone()) param.data, param.sharding_spec, target_sharding_spec
)
.detach()
.clone()
)
return param return param
for node in nodes: for node in nodes:
if node.op == 'call_module': if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target) target_module = node.graph.owning_module.get_submodule(node.target)
# TODO: we need to do more actions to take care of the shared parameters. # TODO: we need to do more actions to take care of the shared parameters.
if hasattr(target_module, 'processed') and target_module.processed: if hasattr(target_module, "processed") and target_module.processed:
continue continue
setattr(target_module, 'processed', True) setattr(target_module, "processed", True)
for name, param in target_module.named_parameters(): for name, param in target_module.named_parameters():
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
param = _shard_param(param, target_sharding_spec) param = _shard_param(param, target_sharding_spec)
...@@ -453,7 +462,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes ...@@ -453,7 +462,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
# apply the sharding spec of buffers # apply the sharding spec of buffers
for name, buffer in target_module.named_buffers(): for name, buffer in target_module.named_buffers():
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {}) origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
setattr(buffer, 'sharding_spec', origin_sharding_spec) setattr(buffer, "sharding_spec", origin_sharding_spec)
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec) buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
sharded_buffer_dict[name] = buffer_sharded sharded_buffer_dict[name] = buffer_sharded
...@@ -461,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes ...@@ -461,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
for name, buffer_sharded in sharded_buffer_dict.items(): for name, buffer_sharded in sharded_buffer_dict.items():
setattr(target_module, name, buffer_sharded.detach().clone()) setattr(target_module, name, buffer_sharded.detach().clone())
if node.op == 'get_attr': if node.op == "get_attr":
root = node.graph.owning_module root = node.graph.owning_module
atoms = node.target.split(".") atoms = node.target.split(".")
attr_len = len(atoms) attr_len = len(atoms)
...@@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule): ...@@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
""" """
replace the origin kernel into kernel with implicit communication inside. replace the origin kernel into kernel with implicit communication inside.
""" """
pass
def runtime_preparation_pass(gm: torch.fx.GraphModule, def runtime_preparation_pass(
solution: List[int], gm: torch.fx.GraphModule,
device_mesh: DeviceMesh, solution: List[int],
strategies_constructor: StrategiesConstructor, device_mesh: DeviceMesh,
overlap=False): strategies_constructor: StrategiesConstructor,
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass( overlap=False,
gm, solution, strategies_constructor) ):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
gm, solution, strategies_constructor
)
gm = size_value_converting_pass(gm, device_mesh) gm = size_value_converting_pass(gm, device_mesh)
gm = node_args_converting_pass(gm, device_mesh) gm = node_args_converting_pass(gm, device_mesh)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed. # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment