Commit 9e768b59 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents 7bc5a8e3 8aed02b9
......@@ -3,7 +3,7 @@ from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
......
from typing import Callable, Dict, List, Tuple, Union
from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info']
__all__ = ["batchnormnd_meta_info", "layernorm_meta_info"]
@meta_register.register(torch.nn.BatchNorm1d)
......@@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# saved inv std and some other args indicating the status of the module
# the bwd outputs are input grad, weight grad and bias grad
bwd_in_args = [
output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch
output_tensor,
output_tensor,
weight_tensor,
mean_tensor,
var_tensor,
mean_tensor,
var_tensor,
1e-5,
num_batch,
]
bwd_out_args = [input_tensor, weight_tensor, bias_tensor]
......@@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# calculate memory cost
# the fwd activation cost is output plus saved mean and saved inv std
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
[input_tensor, output_tensor, mean_tensor, var_tensor]),
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor, mean_tensor, var_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
)
# the bwd memory cost is quite tricky here, BatchNorm will remove saved mean
# and saved inv std during backward phase
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]),
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=compute_size_in_bytes([mean_tensor, var_tensor]),
buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = [torch.zeros_like(mean_tensor, device="meta"), torch.zeros_like(var_tensor, device="meta")]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
......@@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
running_mean = torch.rand(input_tensor.shape[0], 1, device='meta')
running_var = torch.rand(input_tensor.shape[0], 1, device='meta')
running_mean = torch.rand(input_tensor.shape[0], 1, device="meta")
running_var = torch.rand(input_tensor.shape[0], 1, device="meta")
# construct args
fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor]
......@@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
[input_tensor, output_tensor, weight_tensor, bias_tensor]),
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=compute_size_in_bytes([running_mean, running_var]))
buffer=compute_size_in_bytes([running_mean, running_var]),
)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=compute_size_in_bytes([running_mean, running_var]),
buffer=compute_size_in_bytes([running_mean, running_var]))
buffer=compute_size_in_bytes([running_mean, running_var]),
)
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = [torch.zeros_like(running_mean, device="meta"), torch.zeros_like(running_var, device="meta")]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
......@@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
......@@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix]))
# temp memory for backward is the index matrix to be discarded
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
temp=compute_size_in_bytes(index_matrix))
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
temp=compute_size_in_bytes(index_matrix),
)
# total cost
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)
......@@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = [torch.zeros_like(index_matrix, device="meta")]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
......@@ -2,7 +2,6 @@ from typing import Callable, List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
......@@ -37,15 +36,19 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
parameter=0,
temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
buffer=0)
buffer=0,
)
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
total_mem_cost = MemoryCost(
activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
......@@ -66,14 +69,24 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
# register torch.Tensor related metainfo
# (0, 0)
meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze,
torch.arange])(tensor_related_metainfo(0, 0))
meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, torch.arange])(
tensor_related_metainfo(0, 0)
)
# (1, 0)
meta_register.register([
torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute,
torch.Tensor.split, torch.split, torch.Tensor.view
])(tensor_related_metainfo(1, 0))
meta_register.register(
[
torch.Tensor.flatten,
torch.flatten,
torch.Tensor.transpose,
torch.transpose,
torch.Tensor.permute,
torch.permute,
torch.Tensor.split,
torch.split,
torch.Tensor.view,
]
)(tensor_related_metainfo(1, 0))
# (1, 1)
meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1))
......@@ -4,7 +4,7 @@ import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
......@@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li
# gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
# NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward
fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor]))
bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
bwd_mem_cost = MemoryCost(
activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
parameter=0,
temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) -
activation_size([x_tensor, y_tensor]),
buffer=0)
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
temp=activation_size([output_tensor]) * 3
+ activation_size([condition_tensor])
- activation_size([x_tensor, y_tensor]),
buffer=0,
)
total_mem_cost = MemoryCost(
activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
......
__all__ = ['Registry']
__all__ = ["Registry"]
class Registry:
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
def wrapper(func):
if isinstance(source, (list, tuple)):
# support register a list of items for this func
......@@ -21,7 +19,7 @@ class Registry:
return wrapper
def get(self, source):
assert source in self.store, f'{source} not found in the {self.name} registry'
assert source in self.store, f"{source} not found in the {self.name} registry"
target = self.store[source]
return target
......@@ -29,4 +27,4 @@ class Registry:
return source in self.store
meta_register = Registry('meta')
meta_register = Registry("meta")
......@@ -2,20 +2,13 @@ from typing import Callable, List
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register
__all__ = ['ShardMetaInfo']
__all__ = ["ShardMetaInfo"]
class ShardMetaInfo:
......@@ -76,10 +69,12 @@ class ShardMetaInfo:
"""
if isinstance(sharding_spec, ShardingSpec):
op_data = OperationData(name=operation_data.name,
op_data = OperationData(
name=operation_data.name,
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape)
logical_shape=operation_data.logical_shape,
)
elif isinstance(sharding_spec, (list, tuple)):
data = operation_data.data
assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
......@@ -97,8 +92,9 @@ class ShardMetaInfo:
"""
Compute meta info based on sharding strategy and the given target function.
"""
assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
f"Meta info for {self._target} is not registered."
assert meta_register.has(self._target.__class__) or meta_register.has(
self._target
), f"Meta info for {self._target} is not registered."
if meta_register.has(self._target.__class__):
# module
meta_func = meta_register.get(self._target.__class__)
......@@ -117,11 +113,11 @@ class ShardMetaInfo:
# construct kwargs
if self.target in INPLACE_MODULE:
kwargs = {'inplace': self.target.inplace}
kwargs = {"inplace": self.target.inplace}
elif self.target in INPLACE_OPS:
kwargs = {'inplace': True}
kwargs = {"inplace": True}
else:
kwargs = {'inplace': False}
kwargs = {"inplace": False}
# compute metainfo with meta_func
self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
......
from typing import Dict, Tuple
from enum import Enum
from typing import Dict, Tuple
import torch
from torch.optim import Optimizer
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from .base_offload_module import BaseOffloadModule
from .region_manager import RegionManager
from .region import Region
from .region_manager import RegionManager
class OptimState(Enum):
SCALED = 0
UNSCALED = 1
class AMPOptimizer(ColossalaiOptimizer):
class AMPOptimizer(OptimizerWrapper):
"""
A wrapper for Optimizer.
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
......@@ -36,7 +37,8 @@ class AMPOptimizer(ColossalaiOptimizer):
norm_type (float, optional): norm_type used for `clip_grad_norm`.
"""
def __init__(self,
def __init__(
self,
optimizer: Optimizer,
module: BaseOffloadModule,
initial_scale: float = 2**16,
......@@ -47,8 +49,8 @@ class AMPOptimizer(ColossalaiOptimizer):
min_scale: float = 1,
max_scale: float = 2**32,
clipping_norm: float = 0.0,
norm_type: float = 2.0):
norm_type: float = 2.0,
):
super().__init__(optimizer)
self.module = module
......@@ -68,19 +70,21 @@ class AMPOptimizer(ColossalaiOptimizer):
self.__init__optimizer()
# Grad scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
self.grad_scaler = DynamicGradScaler(
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
max_scale=max_scale,
)
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
self._logger = get_dist_logger()
def _set_grad_ptr(self):
for group in self.param_groups:
for fake_param in group['params']:
for fake_param in group["params"]:
region = self.param_to_region[fake_param]
begin, end = self.param_to_range[fake_param]
......@@ -91,7 +95,7 @@ class AMPOptimizer(ColossalaiOptimizer):
def _update_fp16_params(self):
none_tensor = torch.empty([0])
for group in self.param_groups:
for fake_param in group['params']:
for fake_param in group["params"]:
assert fake_param.grad is None
fake_param.data = none_tensor
self.param_to_region[fake_param].cpu_grad = None
......@@ -131,7 +135,7 @@ class AMPOptimizer(ColossalaiOptimizer):
if found_inf:
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
self._logger.info(f'Found overflow. Skip step')
self._logger.info(f"Found overflow. Skip step")
self.zero_grad() # reset all gradients
self._update_fp16_params()
return
......@@ -155,11 +159,10 @@ class AMPOptimizer(ColossalaiOptimizer):
self.module.backward(loss)
def __init__optimizer(self):
for group in self.optim.param_groups:
fake_params_list = list()
for param in group['params']:
for param in group["params"]:
region = self.region_manager.get_region(param)
fake_param = torch.nn.Parameter(torch.empty([0]))
self.param_to_range[fake_param] = region.param_to_range[param]
......@@ -170,7 +173,7 @@ class AMPOptimizer(ColossalaiOptimizer):
if param in self.optim.state:
self.optim.state[fake_param] = self.optim.state.pop(param)
group['params'] = fake_params_list
group["params"] = fake_params_list
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
......
......@@ -4,7 +4,7 @@ from typing import Optional, Set
import torch
import torch.nn as nn
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.utils import _cast_float
from colossalai.zero.legacy.gemini.tensor_utils import free_storage
from .region_manager import RegionManager
......@@ -22,7 +22,6 @@ class BaseOffloadModule:
"""
def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):
self.model = model
self.region_manager = region_manager
self.grad_hook_list = []
......@@ -91,17 +90,16 @@ class BaseOffloadModule:
def parameters(self, recurse: bool = True):
return self.model.parameters(recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True):
def named_parameters(self, prefix: str = "", recurse: bool = True):
return self.model.named_parameters(prefix, recurse)
def named_buffers(self, prefix: str = '', recurse: bool = True):
def named_buffers(self, prefix: str = "", recurse: bool = True):
return self.model.named_buffers(prefix, recurse)
def named_children(self):
return self.model.named_children()
def named_modules(self,
memo: Optional[Set[torch.nn.Module]] = None,
prefix: str = '',
remove_duplicate: bool = True):
def named_modules(
self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):
return self.model.named_modules(memo, prefix, remove_duplicate)
......@@ -14,11 +14,9 @@ from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_
from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem
def memory_optimize(model: torch.nn.Module,
inps: Dict[str, torch.Tensor],
memory_budget: float = -1.0,
solver_name: str = 'asyn'):
def memory_optimize(
model: torch.nn.Module, inps: Dict[str, torch.Tensor], memory_budget: float = -1.0, solver_name: str = "asyn"
):
model = model.cpu().half()
tracer = ColoTracer()
assert is_compatible_with_meta()
......@@ -40,13 +38,13 @@ def memory_optimize(model: torch.nn.Module,
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}"
)
if solver_name == 'syn':
if solver_name == "syn":
gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
elif solver_name == 'asyn':
elif solver_name == "asyn":
gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list)
else:
raise TypeError(f"Unknown solver name {solver_name}!")
gm.recompile()
optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
optimized_model = BaseOffloadModule(gm, region_manager, solver_name == "syn")
return optimized_model
......@@ -55,13 +55,13 @@ class Region:
Map the parameters in the region to a contiguous memory space.
"""
self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda')
self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device="cuda")
offset = 0
for param in self.fp16_params:
param.data = param.data.cuda()
p_num = param.data.numel()
self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape)
self.fp16_data[offset : offset + p_num].copy_(param.data.flatten())
param.data = self.fp16_data[offset : offset + p_num].view(param.data.shape)
self.param_to_range[param] = (offset, offset + p_num)
offset += p_num
......@@ -83,7 +83,7 @@ class Region:
self.temp_fp32_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
alloc_storage(self.fp16_data)
self.fp16_data[:self.param_num].copy_(self.temp_fp32_data)
self.fp16_data[: self.param_num].copy_(self.temp_fp32_data)
self.fp16_data.record_stream(torch.cuda.current_stream())
self.__update_params_ptr()
......@@ -94,7 +94,7 @@ class Region:
"""
self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True)
self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True)
self.cpu_grad.copy_(self.fp16_data[: self.param_num], non_blocking=True)
self.fp16_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
self.free_cuda_data()
......
from typing import List, Any, Dict, Tuple
from typing import Any, Dict, List, Tuple
import torch
from torch.fx import Graph, Node
from .region import Region
from .solver import SolverFactory
from .training_simulator import TrainingSimulator
from .region import Region
from .util import NodeInfo
......@@ -19,14 +20,9 @@ class RegionManager:
cnode (List[str], optional): Common node List, should be the subset of input.
"""
def __init__(self,
graph: Graph,
solver_name: str = 'asyn',
memory_budget: float = -1.0,
cnode: List[str] = None):
def __init__(self, graph: Graph, solver_name: str = "asyn", memory_budget: float = -1.0, cnode: List[str] = None):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
assert graph.owning_module is not None, "The given graph is not associated with a owning_module"
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.cnode = cnode
......@@ -39,7 +35,7 @@ class RegionManager:
self.memory_budget = memory_budget
self.solver_name = solver_name
self.require_pool: bool = solver_name == 'asyn'
self.require_pool: bool = solver_name == "asyn"
self.reg_to_block: Dict[int, int] = dict()
......@@ -61,22 +57,19 @@ class RegionManager:
self._post_process(solver.best_ts)
def _pre_process(self):
init_region_list = self._linearize_graph()
if len(self.shared_region_pairs) > 1:
raise NotImplementedError(
'The current version only considers at most one pair of parameter sharing.')
raise NotImplementedError("The current version only considers at most one pair of parameter sharing.")
elif len(self.shared_region_pairs) == 1:
shared_regs = self.shared_region_pairs[0]
assert shared_regs[0].shared_rid == shared_regs[1].r_id \
and shared_regs[1].shared_rid == shared_regs[0].r_id
assert shared_regs[0].shared_rid == shared_regs[1].r_id and shared_regs[1].shared_rid == shared_regs[0].r_id
fst_id = shared_regs[0].r_id
lst_id = shared_regs[1].r_id
regs_left_out = init_region_list[:fst_id + 1]
regs_left_out = init_region_list[: fst_id + 1]
regs_right_out = init_region_list[lst_id:]
hold_regs = init_region_list[fst_id + 1:lst_id]
hold_regs = init_region_list[fst_id + 1 : lst_id]
else:
regs_left_out = []
regs_right_out = []
......@@ -122,12 +115,9 @@ class RegionManager:
it may not find a suitable region placement strategy for the given execution flow.
"""
reg_flow = torch.cat(
[ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
mem_block_num = torch.max(
torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
coexist_matrix = torch.logical_or(
ts.fwd_reg_flow, ts.bwd_reg_flow)
reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow)
block_to_regs = {}
for block_idx in range(mem_block_num):
......@@ -135,8 +125,7 @@ class RegionManager:
for reg in self.region_list:
if reg.r_id in self.rid_in_pool:
cur_reg_appears = coexist_matrix[:, reg.r_id]
cur_reg_coexists = torch.sum(
coexist_matrix[cur_reg_appears], dim=0).bool()
cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool()
for block_idx in range(mem_block_num):
if not any(cur_reg_coexists[block_to_regs[block_idx]]):
block_to_regs[block_idx].append(reg.r_id)
......@@ -145,9 +134,12 @@ class RegionManager:
if reg.r_id not in self.reg_to_block:
raise NotImplementedError(
f'can not find a block from the memory pool to store parameters of the region')
self.memory_pool = torch.chunk(torch.zeros(int(
mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num))
f"can not find a block from the memory pool to store parameters of the region"
)
self.memory_pool = torch.chunk(
torch.zeros(int(mem_block_num * self.mem_block_size / 2), dtype=torch.half, device="cuda"),
chunks=int(mem_block_num),
)
def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
"""
......@@ -178,10 +170,9 @@ class RegionManager:
return region_list
def _search_block_size(self,
region_list: List[Region],
search_interval_byte: int = 1024,
search_range_byte: int = 128 * 1024 ** 2) -> int:
def _search_block_size(
self, region_list: List[Region], search_interval_byte: int = 1024, search_range_byte: int = 128 * 1024**2
) -> int:
"""
Search for a suitable memory block size.
......@@ -208,11 +199,10 @@ class RegionManager:
acc_wasted += blk_size - left
return acc_wasted
param_size_list = [
region.param_size for region in region_list if region.r_id == region.shared_rid]
param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid]
start_size = max(param_size_list)
min_mem_waste = float('+inf')
min_mem_waste = float("+inf")
best_block_size = start_size
for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
......@@ -229,7 +219,7 @@ class RegionManager:
Initialize region data, which maps the parameters in the region to a contiguous memory space.
"""
self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32)
self.temp_fp32_data = torch.zeros(self.max_param_num, device="cuda", dtype=torch.float32)
for region in self.region_list:
pre_alloc_tensor = None
......@@ -244,8 +234,7 @@ class RegionManager:
region.fp16_data = shared_region.fp16_data
region.fp32_data = shared_region.fp32_data
region.param_to_range = shared_region.param_to_range
region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach(
)
region.temp_fp32_data = self.temp_fp32_data[: region.param_num].detach()
torch.cuda.empty_cache()
......@@ -259,13 +248,14 @@ class RegionManager:
former_reg, latter_reg = self.shared_region_pairs[0]
assert latter_reg.param_num >= former_reg.param_num
embedding_node = former_reg.nodes[-1]
assert embedding_node.op == 'call_module' and isinstance(
self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding)
assert embedding_node.op == "call_module" and isinstance(
self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding
)
if latter_reg.param_num > former_reg.param_num:
for idx, n in enumerate(latter_reg.nodes):
if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target),
torch.nn.Linear)) or \
(n.op == 'call_function' and n.target is torch.nn.functional.linear):
if (
n.op == "call_module" and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear)
) or (n.op == "call_function" and n.target is torch.nn.functional.linear):
cut_node_idx = idx + 1
break
assert len(latter_reg.fp16_params) == 2
......@@ -273,7 +263,7 @@ class RegionManager:
for p in new_reg.fp16_params:
self.param_region_map[p] = new_reg
self.region_list.insert(new_reg.r_id, new_reg)
for reg in self.region_list[new_reg.r_id + 1:]:
for reg in self.region_list[new_reg.r_id + 1 :]:
reg.r_id += 1
latter_reg.shared_rid = former_reg.r_id
former_reg.shared_rid = latter_reg.r_id
......@@ -362,14 +352,12 @@ class RegionManager:
"""
def _is_inplace(n: Node):
"""Get the inplace argument from ``torch.fx.Node``
"""
"""Get the inplace argument from ``torch.fx.Node``"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(
n.target), "inplace", False)
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace
label = False
......@@ -385,21 +373,23 @@ class RegionManager:
elif n.op == "call_function":
label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any(
map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes))
map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)
)
return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users))
def _exception_node_handling():
# TODO meta info prop bug
if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2:
n.meta['fwd_out'] = []
if n.name.__contains__("transpose") and n.meta["fwd_out"][0].dim() <= 2:
n.meta["fwd_out"] = []
# make sure that item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
f"Common node {name} is not an input of the model."
assert (
next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
), f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
else:
......@@ -428,8 +418,8 @@ class RegionManager:
ns = []
border_n_idx = region.nodes.index(act_n)
if border_n_idx < len(region.nodes):
ns = region.nodes[border_n_idx + 1:]
region.nodes = region.nodes[:border_n_idx + 1]
ns = region.nodes[border_n_idx + 1 :]
region.nodes = region.nodes[: border_n_idx + 1]
region_list.append(region)
region_id += 1
region = Region(r_id=region_id)
......@@ -448,19 +438,21 @@ class RegionManager:
region = Region(r_id=region_id)
# propagate common node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
]) or _is_cop(n.target):
if len(n.all_input_nodes) == len(
[node for node in n.all_input_nodes if node.name in self.cnode]
) or _is_cop(n.target):
self.cnode.append(n.name)
else:
deps[n] = len(
[user for user in n.users if user.op != "output"])
deps[n] = len([user for user in n.users if user.op != "output"])
# propagate param node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops
]) or n.op == "get_attr":
if (
len(n.all_input_nodes)
== len([node for node in n.all_input_nodes if node.name in self.only_param_ops])
or n.op == "get_attr"
):
self.only_param_ops.append(n.name)
param_op_deps[n] = len(
[user for user in n.users if user.op != "output"])
param_op_deps[n] = len([user for user in n.users if user.op != "output"])
# record last activation node
if _is_act(n._meta_data):
......@@ -472,19 +464,16 @@ class RegionManager:
return region_list
def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
cur_n.node_info = NodeInfo(node_id)
if cur_n.op == 'call_module':
if cur_n.op == "call_module":
target = cur_n.target
submod = self.root_module.get_submodule(target)
for p in list(submod.parameters(recurse=False)):
if p in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[p].r_id
self.param_region_map[p].shared_rid = cur_reg.r_id
self.shared_region_pairs.append(
(self.param_region_map[p], cur_reg))
self.shared_region_pairs.append((self.param_region_map[p], cur_reg))
else:
self.param_region_map[p] = cur_reg
......@@ -499,12 +488,10 @@ class RegionManager:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.nn.Parameter):
if attr_itr in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
self.shared_region_pairs.append(
(self.param_region_map[attr_itr], cur_reg))
self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg))
else:
self.param_region_map[attr_itr] = cur_reg
......
......@@ -22,13 +22,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
d2h_rid = fwd_info.get('d2h_rid', None)
d2h_rid = fwd_info.get("d2h_rid", None)
if d2h_rid is not None:
free_region = GlobalRuntimeInfo().region_list[d2h_rid]
assert isinstance(free_region, Region)
free_region.free_cuda_data()
h2d_rid = fwd_info.get('h2d_rid', None)
h2d_rid = fwd_info.get("h2d_rid", None)
if h2d_rid is not None:
h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(h2d_region, Region)
......@@ -38,8 +38,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
h2d_rid = ctx.bwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
......@@ -64,13 +63,13 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
sync_rid = fwd_info.get('sync_rid', None)
sync_rid = fwd_info.get("sync_rid", None)
if sync_rid is not None:
prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
if prefetch_event:
prefetch_event.wait()
h2d_rid = fwd_info.get('h2d_rid', None)
h2d_rid = fwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
......@@ -87,8 +86,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
sync_rid = ctx.bwd_info.get('sync_rid', None)
sync_rid = ctx.bwd_info.get("sync_rid", None)
if sync_rid is not None:
wait_region = GlobalRuntimeInfo().region_list[sync_rid]
assert isinstance(wait_region, Region)
......@@ -98,7 +96,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
else:
wait_region.move_param_to_cuda()
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
h2d_rid = ctx.bwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
......@@ -114,7 +112,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
'''
"""
Convert Upload and Offload operation into runtime action.
Argument:
......@@ -123,14 +121,14 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be uploaded, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be uploaded during backward pass.
'''
"""
with torch._C.DisableTorchFunction():
ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
'''
"""
Convert Prefetch and Offload operation into runtime action.
Argument:
......@@ -139,7 +137,7 @@ def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be prefetched, waited, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be prefetched or waited during backward pass.
'''
"""
with torch._C.DisableTorchFunction():
ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
......@@ -176,22 +174,22 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R
# forward upload
fwd_info = {}
if requires_upload_p_in_fwd(region_list[region.shared_rid]):
fwd_info['h2d_rid'] = region.r_id
fwd_info["h2d_rid"] = region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
fwd_info['d2h_rid'] = r_idx - 1
fwd_info["d2h_rid"] = r_idx - 1
bwd_info = {}
# backward upload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id
bwd_info["h2d_rid"] = region_list[r_idx - 1].r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function',
convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info))
new_node = mod_graph.create_node(
"call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info)
)
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
......@@ -210,9 +208,9 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
first_region_with_p = [region for region in region_list if region.param_size][0]
fwd_info = {"h2d_rid": first_region_with_p.r_id}
with mod_graph.inserting_after(last_inp_node):
upload_apply_node = mod_graph.create_node('call_function',
convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, {}))
upload_apply_node = mod_graph.create_node(
"call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, {})
)
replace_node_users(last_inp_node, upload_apply_node)
last_inp_node = upload_apply_node
......@@ -220,37 +218,39 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
# forward prefetch
fwd_info = {}
if region.param_size:
fwd_info['sync_rid'] = region.r_id
fwd_info["sync_rid"] = region.r_id
fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
fwd_info['h2d_rid'] = fwd_prefetch_region.r_id
fwd_info["h2d_rid"] = fwd_prefetch_region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
fwd_info['d2h_rid'] = r_idx - 1
fwd_info["d2h_rid"] = r_idx - 1
bwd_info = {}
# backward prefetch
if r_idx > 0 and region_list[r_idx - 1].need_offload:
bwd_info['sync_rid'] = r_idx - 1
bwd_info["sync_rid"] = r_idx - 1
if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id
bwd_info["h2d_rid"] = region_list[r_idx - 1].bwd_prefetch_region.r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function',
new_node = mod_graph.create_node(
"call_function",
convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info))
args=(last_inp_node, fwd_info, bwd_info),
)
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
if region.bwd_prefetch_region:
bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
bwd_info = {"h2d_rid": region.bwd_prefetch_region.r_id}
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function',
convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, {}, bwd_info))
new_node = mod_graph.create_node(
"call_function", convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, {}, bwd_info)
)
replace_node_users(last_inp_node, new_node)
# gm.graph.print_tabular()
return gm
import time
from typing import List, Dict, Type
from abc import ABC, abstractmethod
from typing import Dict, List, Type
NOT_NVML = False
try:
......@@ -10,10 +10,11 @@ except:
import torch
from torch.fx.node import Node
from colossalai.utils.cuda import get_current_device
from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator
from .region import Region
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
from .util import NodeInfo, NvDevicePower
......@@ -49,19 +50,14 @@ class Solver(ABC):
It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time.
"""
def __init__(self,
region_list: List[Region],
memory_budget: float = -1.0,
error_factor: float = 0.95) -> None:
def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error_factor: float = 0.95) -> None:
self.region_list = region_list
self.error_factor: float = error_factor
if memory_budget > 0:
self.memory_budget = memory_budget * self.error_factor
else:
self.memory_budget = torch.cuda.get_device_properties(
get_current_device()).total_memory * self.error_factor
self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
self.comp_power: float = self._extract_computing_power()
......@@ -94,7 +90,7 @@ class Solver(ABC):
if extra_cost == 0:
# means data transfer overhead can be completely overlapped
return (float('inf'), total_mem_saving, peak_mem_saving)
return (float("inf"), total_mem_saving, peak_mem_saving)
return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)
def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool:
......@@ -122,9 +118,7 @@ class Solver(ABC):
self.best_ts = best_ts
self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem)
def _update_node_mem_info(self,
fwd_mem_info: Dict[Node, float],
bwd_mem_info: Dict[Node, float]):
def _update_node_mem_info(self, fwd_mem_info: Dict[Node, float], bwd_mem_info: Dict[Node, float]):
"""
Update the runtime memory information of the node.
......@@ -134,12 +128,10 @@ class Solver(ABC):
"""
for node, mem in fwd_mem_info.items():
assert hasattr(node, 'node_info') and isinstance(
node.node_info, NodeInfo)
assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
node.node_info.runtime_fwd_mem = mem
for node, mem in bwd_mem_info.items():
assert hasattr(node, 'node_info') and isinstance(
node.node_info, NodeInfo)
assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
node.node_info.runtime_bwd_mem = mem
def _extract_computing_power(self):
......@@ -159,12 +151,12 @@ class Solver(ABC):
return NvDevicePower.RTX3080_FP16 * units
elif device_name.__contains__("RTX 3090"):
return NvDevicePower.RTX3090_FP16 * units
elif device_name.__contains__('V100'):
elif device_name.__contains__("V100"):
return NvDevicePower.V100_FP16 * units
elif device_name.__contains__("A100"):
return NvDevicePower.A100_FP16 * units
else:
raise TypeError(f'Unknown NVIDIA GPU device name {device_name}')
raise TypeError(f"Unknown NVIDIA GPU device name {device_name}")
def _profile_bandwidth(self):
"""
......@@ -172,9 +164,9 @@ class Solver(ABC):
using data volumes ranging from 1KB to 1GB.
"""
print('profiling bandwidth ......')
print("profiling bandwidth ......")
link_to_bandwidth = {}
links = ['h2d', 'd2h']
links = ["h2d", "d2h"]
for link in links:
t_size = 1024
......@@ -182,24 +174,22 @@ class Solver(ABC):
# from 1KB to 1GB
for i in range(21):
if link == 'h2d':
src_tensor = torch.ones(
int(t_size), dtype=torch.int8, pin_memory=True)
dst_tensor = torch.ones(
(int(t_size)), dtype=torch.int8, device='cuda')
elif link == 'd2h':
src_tensor = torch.ones(
int(t_size), dtype=torch.int8, device='cuda')
dst_tensor = torch.ones(
(int(t_size)), dtype=torch.int8, pin_memory=True)
if link == "h2d":
src_tensor = torch.ones(int(t_size), dtype=torch.int8, pin_memory=True)
dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, device="cuda")
elif link == "d2h":
src_tensor = torch.ones(int(t_size), dtype=torch.int8, device="cuda")
dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, pin_memory=True)
def func():
dst_tensor.copy_(src_tensor)
size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3)
print(f'size: {t_size / 1024 ** 2:.3f} MB, '
f'{src_tensor.device.type}-to-{dst_tensor.device.type} '
f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s')
print(
f"size: {t_size / 1024 ** 2:.3f} MB, "
f"{src_tensor.device.type}-to-{dst_tensor.device.type} "
f"bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s"
)
t_size *= 2
......@@ -208,10 +198,7 @@ class Solver(ABC):
class SynGreedySolver(Solver):
def __init__(self,
region_list: List[Region],
memory_budget: float = -1.0) -> None:
def __init__(self, region_list: List[Region], memory_budget: float = -1.0) -> None:
super().__init__(region_list, memory_budget)
self.best_ts: SynTrainingSimulator = None
......@@ -258,7 +245,8 @@ class SynGreedySolver(Solver):
else:
raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
)
def _call_solver_l2l(self):
"""
......@@ -270,7 +258,6 @@ class SynGreedySolver(Solver):
region.is_syn = True
def _try_to_offload(self, offload_region: Region):
# record previous information
orig_need_offload = offload_region.need_offload
assert not orig_need_offload
......@@ -297,23 +284,17 @@ class SynGreedySolver(Solver):
ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
extra_comm_cost = 2.0 * \
ts._get_communication_overhead('h2d', offload_region.param_size)
extra_comm_cost = 2.0 * ts._get_communication_overhead("h2d", offload_region.param_size)
# the shared region needs to be moved twice
if offload_region.r_id < offload_region.shared_rid:
extra_comm_cost *= 2.0
profit = self._compute_offload_profit(
ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit
class AsynGreedySolver(Solver):
def __init__(self,
region_list: List[Region],
memory_budget: float = -1.0,
search_window_size: int = 3):
def __init__(self, region_list: List[Region], memory_budget: float = -1.0, search_window_size: int = 3):
super().__init__(region_list, memory_budget)
self.search_window_size = search_window_size
......@@ -331,7 +312,7 @@ class AsynGreedySolver(Solver):
ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
self._update_state(ts)
print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB")
print("init peak memory", self.best_ts.peak_mem / 1024**2, "MB")
def _call_solver(self):
"""
......@@ -358,18 +339,17 @@ class AsynGreedySolver(Solver):
best_pref_ts = None
# search when to prefetch the region offloaded
for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]:
for host_region in self.region_list[region.r_id + 1 : region.r_id + 1 + self.search_window_size]:
if host_region.bwd_prefetch_region is not None:
continue
temp_ts, profit = self._try_to_offload(
host_region, region)
temp_ts, profit = self._try_to_offload(host_region, region)
if self._compare_profit(profit, max_prefetch_profit):
region_to_region_map[region.r_id] = host_region
max_prefetch_profit = profit
best_pref_ts = temp_ts
if profit[0] == float('inf'):
if profit[0] == float("inf"):
break
if self._compare_profit(max_prefetch_profit, max_offload_profit):
......@@ -392,7 +372,8 @@ class AsynGreedySolver(Solver):
else:
raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
)
region_to_region_map.clear()
......@@ -452,7 +433,6 @@ class AsynGreedySolver(Solver):
peak_mem_saving = 0
while len(self.region_to_region_map) and peak_mem_saving <= 0:
max_profit = (0,)
best_ts = None
undo_host_region = None
......@@ -464,8 +444,7 @@ class AsynGreedySolver(Solver):
assert offload_region.need_offload
assert not offload_region.is_syn
ts, profit = self._try_convert_to_syn_upload(host_region,
offload_region)
ts, profit = self._try_convert_to_syn_upload(host_region, offload_region)
if self._compare_profit(profit, max_profit):
undo_host_region = host_region
......@@ -474,7 +453,7 @@ class AsynGreedySolver(Solver):
best_ts = ts
if best_ts is None:
raise NotImplementedError('repair error!')
raise NotImplementedError("repair error!")
assert not undo_offload_region.is_syn
undo_offload_region.is_syn = True
......@@ -500,17 +479,13 @@ class AsynGreedySolver(Solver):
ts.execute()
extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0)
profit = self._compute_offload_profit(
ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit
class SolverFactory:
solvers: Dict[str, Type[Solver]] = {
'syn': SynGreedySolver,
'asyn': AsynGreedySolver
}
solvers: Dict[str, Type[Solver]] = {"syn": SynGreedySolver, "asyn": AsynGreedySolver}
@staticmethod
def create(solver_name: str) -> Type[Solver]:
......
import bisect
from typing import List, Dict
from collections import OrderedDict
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict, List
from torch.fx.node import Node
......@@ -26,10 +26,7 @@ class TrainingSimulator(ABC):
link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth.
"""
def __init__(self,
region_list: List[Region],
comp_power: float,
link_to_bw: Dict[str, Dict[float, float]]) -> None:
def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
self.region_list = region_list
self.region_num = len(region_list)
......@@ -87,11 +84,7 @@ class TrainingSimulator(ABC):
class SynTrainingSimulator(TrainingSimulator):
def __init__(self,
region_list: List[Region],
comp_power: float,
link_to_bw: Dict[str, Dict[float, float]]) -> None:
def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
super().__init__(region_list, comp_power, link_to_bw)
def execute(self):
......@@ -115,8 +108,7 @@ class SynTrainingSimulator(TrainingSimulator):
self.runtime_mem += region.param_size
for node in region.nodes:
self.runtime_mem += calculate_fwd_tmp(node) + \
calculate_fwd_out(node)
self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
self.fwd_node_mem[node] = self.runtime_mem
self.peak_mem = max(self.runtime_mem, self.peak_mem)
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
......@@ -141,18 +133,15 @@ class SynTrainingSimulator(TrainingSimulator):
self.runtime_mem += region.param_size
for node in region.nodes.__reversed__():
self.runtime_mem -= calculate_fwd_out(node)
self.runtime_mem += node.meta['bwd_mem_tmp'] + \
node.meta['bwd_mem_out']
self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
self.peak_mem = max(self.runtime_mem, self.peak_mem)
# The memory savings of a node may be negative due to parameter prefetch.
self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem
self.bwd_node_mem[node] = self.runtime_mem
self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
calculate_fwd_tmp(node))
self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)
# free bwd_mem_out
self.bwd_node_deps[node] = len(node.all_input_nodes)
......@@ -160,12 +149,14 @@ class SynTrainingSimulator(TrainingSimulator):
if user_node in self.bwd_node_deps:
self.bwd_node_deps[user_node] -= 1
if self.bwd_node_deps[user_node] <= 0:
self.runtime_mem -= user_node.meta['bwd_mem_out']
self.runtime_mem -= user_node.meta["bwd_mem_out"]
if self.runtime_mem < 0:
raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
raise ValueError(
f"region id: {region.r_id}, node name: {node.name}, "
f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
f"runtime memory computed less than 0, which is miscalculated!")
f"runtime memory computed less than 0, which is miscalculated!"
)
# release parameter and offload gradient in region
if region.r_id == region.shared_rid:
......@@ -177,23 +168,16 @@ class SynTrainingSimulator(TrainingSimulator):
class AsynTrainingSimulator(TrainingSimulator):
def __init__(self,
region_list: List[Region],
comp_power: float,
link_to_bw: Dict[str, Dict[float, float]]) -> None:
def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
super().__init__(region_list, comp_power, link_to_bw)
self.iter_end_time: int = 0
# the last computation execution period
self.last_comp: ExecutionPeriod = ExecutionPeriod(
start_time=0, end_time=0)
self.last_comp: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
# the last parameter prefetch execution period
self.last_h2d: ExecutionPeriod = ExecutionPeriod(
start_time=0, end_time=0)
self.last_h2d: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
# the last gradient offload execution period
self.last_d2h: ExecutionPeriod = ExecutionPeriod(
start_time=0, end_time=0)
self.last_d2h: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
# the forward computation execution period of the region
self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the forward parameter prefetch execution period of the region
......@@ -204,10 +188,8 @@ class AsynTrainingSimulator(TrainingSimulator):
self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the gradient offload execution period of the region
# which is divided into those that are waiting and those that have been released
self.bwd_reg_to_offl_waiting: OrderedDict[int,
ExecutionPeriod] = OrderedDict()
self.bwd_reg_to_offl_freed: OrderedDict[int,
ExecutionPeriod] = OrderedDict()
self.bwd_reg_to_offl_waiting: OrderedDict[int, ExecutionPeriod] = OrderedDict()
self.bwd_reg_to_offl_freed: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the region buffer, which records regions that are offloaded but not released
self.reg_buffer_to_free: List[int] = []
......@@ -217,10 +199,8 @@ class AsynTrainingSimulator(TrainingSimulator):
# the region execution flow,
# where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU
# when the execution reaches the i-th region.
self.fwd_reg_flow = torch.zeros(
(self.region_num, self.region_num)).bool()
self.bwd_reg_flow = torch.zeros(
(self.region_num, self.region_num)).bool()
self.fwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
self.bwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
def execute(self):
"""
......@@ -232,7 +212,7 @@ class AsynTrainingSimulator(TrainingSimulator):
for reg in self.region_list:
if reg.param_size and reg.r_id < self.region_num - 1:
for nr in self.region_list[reg.r_id + 1:]:
for nr in self.region_list[reg.r_id + 1 :]:
if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]):
reg.fwd_prefetch_region = nr
break
......@@ -249,8 +229,7 @@ class AsynTrainingSimulator(TrainingSimulator):
self.runtime_mem -= self.region_list[reg_id].param_size
self.bwd_reg_to_offl_waiting.clear()
self.iter_end_time = max(
self.last_comp.end_time, self.last_d2h.end_time)
self.iter_end_time = max(self.last_comp.end_time, self.last_d2h.end_time)
def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
"""
......@@ -258,10 +237,8 @@ class AsynTrainingSimulator(TrainingSimulator):
"""
pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time)
pref_end_time = pref_start_time + \
2.0 * self._get_communication_overhead('h2d', region.param_size)
pref_ep = ExecutionPeriod(
start_time=pref_start_time, end_time=pref_end_time)
pref_end_time = pref_start_time + 2.0 * self._get_communication_overhead("h2d", region.param_size)
pref_ep = ExecutionPeriod(start_time=pref_start_time, end_time=pref_end_time)
if is_fwd:
self.fwd_reg_to_pref[region.r_id] = pref_ep
else:
......@@ -276,18 +253,16 @@ class AsynTrainingSimulator(TrainingSimulator):
if is_fwd:
reg_to_comp = self.fwd_reg_to_comp
reg_to_pref = self.fwd_reg_to_pref
flop_key = 'fwd_flop'
flop_key = "fwd_flop"
else:
reg_to_comp = self.bwd_reg_to_comp
reg_to_pref = self.bwd_reg_to_pref
flop_key = 'bwd_flop'
comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(
region.r_id, ExecutionPeriod(0, 0)).end_time)
comp_end_time = comp_start_time + \
sum([self._get_computing_overhead(node.meta.get(flop_key, 0))
for node in region.nodes])
comp_ep = ExecutionPeriod(
start_time=comp_start_time, end_time=comp_end_time)
flop_key = "bwd_flop"
comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(region.r_id, ExecutionPeriod(0, 0)).end_time)
comp_end_time = comp_start_time + sum(
[self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes]
)
comp_ep = ExecutionPeriod(start_time=comp_start_time, end_time=comp_end_time)
reg_to_comp[region.r_id] = comp_ep
self.last_comp = comp_ep
......@@ -297,10 +272,8 @@ class AsynTrainingSimulator(TrainingSimulator):
"""
offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time)
offl_end_time = offl_start_time + \
self._get_communication_overhead('d2h', region.param_size)
offl_ep = ExecutionPeriod(
start_time=offl_start_time, end_time=offl_end_time)
offl_end_time = offl_start_time + self._get_communication_overhead("d2h", region.param_size)
offl_ep = ExecutionPeriod(start_time=offl_start_time, end_time=offl_end_time)
self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep
self.last_d2h = offl_ep
......@@ -332,20 +305,17 @@ class AsynTrainingSimulator(TrainingSimulator):
self.fwd_reg_flow[region.r_id, region.r_id] = True
else:
self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1]
self.fwd_reg_flow[region.r_id,
self.reg_buffer_to_free] = False
self.fwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
self.reg_buffer_to_free.clear()
# prefetch parameters of the next region
fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
self.runtime_mem += fwd_prefetch_region.param_size
self.fwd_reg_flow[region.r_id,
fwd_prefetch_region.r_id] = True
self.fwd_reg_flow[region.r_id, fwd_prefetch_region.r_id] = True
for node in region.nodes:
self.runtime_mem += calculate_fwd_tmp(node) + \
calculate_fwd_out(node)
self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
self.peak_mem = max(self.runtime_mem, self.peak_mem)
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
......@@ -354,8 +324,7 @@ class AsynTrainingSimulator(TrainingSimulator):
if region.need_offload:
self.runtime_mem -= region.param_size
assert len(
self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}'
assert len(self.reg_buffer_to_free) <= 1, f"{len(self.reg_buffer_to_free)}"
self.reg_buffer_to_free.append(region.r_id)
def _eval_bwd_cost_per_region(self, region: Region):
......@@ -398,8 +367,7 @@ class AsynTrainingSimulator(TrainingSimulator):
self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1]
else:
self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1]
self.bwd_reg_flow[region.r_id,
self.reg_buffer_to_free] = False
self.bwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
# free gradients in the buffer
while len(self.reg_buffer_to_free):
......@@ -415,8 +383,7 @@ class AsynTrainingSimulator(TrainingSimulator):
bwd_prefetch_region = region.bwd_prefetch_region
if bwd_prefetch_region:
self.runtime_mem += bwd_prefetch_region.param_size
self.bwd_reg_flow[region.r_id,
bwd_prefetch_region.r_id] = True
self.bwd_reg_flow[region.r_id, bwd_prefetch_region.r_id] = True
# add the gradient of the parameter
if region.r_id < region.shared_rid:
......@@ -426,10 +393,8 @@ class AsynTrainingSimulator(TrainingSimulator):
self.runtime_mem += region.param_size
for node in region.nodes.__reversed__():
self.runtime_mem -= calculate_fwd_out(node)
self.runtime_mem += node.meta['bwd_mem_tmp'] + \
node.meta['bwd_mem_out']
self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
self.peak_mem = max(self.runtime_mem, self.peak_mem)
# The memory savings of a node may be negative due to parameter prefetch.
......@@ -437,8 +402,7 @@ class AsynTrainingSimulator(TrainingSimulator):
self.bwd_node_mem[node] = self.runtime_mem
self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
calculate_fwd_tmp(node))
self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)
# free bwd_mem_out
self.bwd_node_deps[node] = len(node.all_input_nodes)
......@@ -446,12 +410,14 @@ class AsynTrainingSimulator(TrainingSimulator):
if user_node in self.bwd_node_deps:
self.bwd_node_deps[user_node] -= 1
if self.bwd_node_deps[user_node] <= 0:
self.runtime_mem -= user_node.meta['bwd_mem_out']
self.runtime_mem -= user_node.meta["bwd_mem_out"]
if self.runtime_mem < 0:
raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
raise ValueError(
f"region id: {region.r_id}, node name: {node.name}, "
f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
f"runtime memory computed less than 0, which is miscalculated!")
f"runtime memory computed less than 0, which is miscalculated!"
)
# release parameters of the region
if requires_release_p_in_bwd(self.region_list[region.shared_rid]):
......
......@@ -35,7 +35,6 @@ class NvDevicePower:
class GlobalRuntimeInfo(metaclass=SingletonMeta):
def __init__(self):
self.h2d_stream = torch.cuda.Stream()
self.d2h_stream = torch.cuda.Stream()
......@@ -50,21 +49,18 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
# forward
for region in region_list:
for node in region.nodes:
runtime_mem = runtime_mem + \
calculate_fwd_tmp(node) + calculate_fwd_out(node)
runtime_mem = runtime_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
act_peak_mem = max(runtime_mem, act_peak_mem)
# backward
bwd_deps = {}
for region in region_list.__reversed__():
for node in region.nodes.__reversed__():
runtime_mem -= calculate_fwd_out(node)
runtime_mem = runtime_mem + \
node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out']
runtime_mem = runtime_mem + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
act_peak_mem = max(runtime_mem, act_peak_mem)
runtime_mem = runtime_mem - \
node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node)
runtime_mem = runtime_mem - node.meta["bwd_mem_tmp"] - calculate_fwd_tmp(node)
# free bwd_mem_out
bwd_deps[node] = len(node.all_input_nodes)
......@@ -72,7 +68,7 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
if user_node in bwd_deps:
bwd_deps[user_node] -= 1
if bwd_deps[user_node] <= 0:
runtime_mem -= user_node.meta['bwd_mem_out']
runtime_mem -= user_node.meta["bwd_mem_out"]
return act_peak_mem
......@@ -86,13 +82,15 @@ def compute_total_param_mem(region_list: List[Region]) -> float:
def requires_upload_p_in_fwd(shared_reg: Region):
return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
and shared_reg.need_offload)
return (shared_reg.r_id >= shared_reg.shared_rid) or (
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
)
def requires_release_p_in_bwd(shared_reg: Region):
return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
and shared_reg.need_offload)
return (shared_reg.r_id >= shared_reg.shared_rid) or (
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
)
def requires_offload_g_in_bwd(region: Region):
......
......@@ -14,18 +14,20 @@ from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
target_sharding_spec: ShardingSpec) -> ShardMetaInfo:
def _construct_shard_meta_info(
node: Node, origin_sharding_spec: ShardingSpec, target_sharding_spec: ShardingSpec
) -> ShardMetaInfo:
# get comm_action_sequence and total_cost from shape_consistency_manager
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
origin_sharding_spec, target_sharding_spec)
origin_sharding_spec, target_sharding_spec
)
meta_info = ShardMetaInfo()
# NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
# get mem cost for ShardMetaInfo
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
# extract user that has _meta_data and extract element length
input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
input_node = next(n for n in node._input_nodes if hasattr(n, "_meta_data"))
element_length = input_node._meta_data.element_size()
mem_cost.fwd.activation *= element_length
......@@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
meta_info.memory_cost = mem_cost
# get computation cost for ShardMetaInfo
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
total_cost['backward'] * element_length,
total_cost['total'] * element_length)
meta_info.compute_cost = TrainCycleItem(
total_cost["forward"] * element_length,
total_cost["backward"] * element_length,
total_cost["total"] * element_length,
)
# get tensor shape for ShardMetaInfo
origin_sharding_spec: ShardingSpec
......@@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
input_shape = origin_sharding_spec.get_sharded_shape_per_device()
output_shape = target_sharding_spec.get_sharded_shape_per_device()
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
return meta_info
......@@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
# extract node index and user node index
args = node.args
node_index, user_node_index = args[3], args[4]
origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
user_node_index]
origin_sharding_spec, target_sharding_spec = (
origin_spec_dict[node_index],
sharding_spec_dict[node_index][user_node_index],
)
return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
......@@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S
# this case is for all_reduce, there will be no memory cost
meta_info = ShardMetaInfo()
meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)
output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
output_node = next(n for n in node.users if hasattr(n, "_meta_data"))
element_length = output_node._meta_data.element_size()
total_cost = comm_action.comm_spec.get_comm_cost()
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
total_cost['backward'] * element_length,
total_cost['total'] * element_length)
meta_info.compute_cost = TrainCycleItem(
total_cost["forward"] * element_length,
total_cost["backward"] * element_length,
total_cost["total"] * element_length,
)
input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
else:
# this case will be handled by shape consistency manager
origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
'tgt_spec']
origin_sharding_spec, target_sharding_spec = (
comm_action.comm_spec["src_spec"],
comm_action.comm_spec["tgt_spec"],
)
meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
return meta_info
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
comm_actions_dict: Dict) -> GraphModule:
def comm_metainfo_pass(
gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, comm_actions_dict: Dict
) -> GraphModule:
"""
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
"""
for node in gm.graph.nodes:
if node.target == runtime_apply:
setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
setattr(node, "best_strategy_info", _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
elif node.target == runtime_comm_spec_apply:
setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
setattr(node, "best_strategy_info", _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
else:
pass
return gm
......@@ -21,16 +21,15 @@ def _normalize_tuple(x):
@compatibility(is_backward_compatible=False)
class MetaInfoProp:
def __init__(self, module: GraphModule) -> None:
self.module = module
self.func_dict = {
'placeholder': self.placeholder_handler,
'get_attr': self.get_attr_handler,
'output': self.output_handler,
'call_function': self.node_handler,
'call_module': self.node_handler,
'call_method': self.node_handler,
"placeholder": self.placeholder_handler,
"get_attr": self.get_attr_handler,
"output": self.output_handler,
"call_function": self.node_handler,
"call_module": self.node_handler,
"call_method": self.node_handler,
}
def _set_data_ptr(self, x):
......@@ -46,7 +45,7 @@ class MetaInfoProp:
"""
Check if the node is inplace operation.
"""
if node.op == 'call_module':
if node.op == "call_module":
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
elif node.op == "call_function":
return node.target in OUTPUT_SAVED_OPS
......@@ -66,7 +65,7 @@ class MetaInfoProp:
Handle the placeholder node.
"""
graph_info = GraphInfo()
out = _normalize_tuple(getattr(node, '_meta_data', None))
out = _normalize_tuple(getattr(node, "_meta_data", None))
graph_info.fwd_out = list(out) if out[0] is not None else []
node.meta = {**asdict(graph_info)}
......@@ -96,7 +95,7 @@ class MetaInfoProp:
"""
Handle other kind of nodes
"""
assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
assert hasattr(node, "best_strategy_info"), f"Cannot find best_strategy_info in node {node}, {node.op}"
graph_info = GraphInfo()
meta_info = node.best_strategy_info
meta_info: ShardMetaInfo
......@@ -126,7 +125,8 @@ class MetaInfoProp:
for tensor in par.meta.get("fwd_out", []):
tensor: torch.Tensor
target_input_tensor = next(
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None
)
if target_input_tensor is not None:
target_input_tensor.data_ptr = tensor.data_ptr
......@@ -148,7 +148,7 @@ class MetaInfoProp:
graph_info.fwd_tmp = buffer_tensors
graph_info.fwd_out = output_tensors
# fetch other memory informations
# fetch other memory information
memory_cost = meta_info.memory_cost
graph_info.fwd_mem_tmp = memory_cost.fwd.temp
graph_info.fwd_mem_out = memory_cost.fwd.activation
......
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationData,
OperationDataType,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
......@@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
user_node_index: int):
def runtime_apply_for_iterable_object(
node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int
):
"""
This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list
is converted into the user node expected form.
"""
rst = []
for index, (origin_sharding_spec,
target_sharding_spec) in enumerate(zip(origin_dict[node_index],
input_dict[node_index][user_node_index])):
for index, (origin_sharding_spec, target_sharding_spec) in enumerate(
zip(origin_dict[node_index], input_dict[node_index][user_node_index])
):
rst.append(
shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
target_sharding_spec))
shape_consistency_manager.apply_for_autoparallel_runtime(
node[index], origin_sharding_spec, target_sharding_spec
)
)
rst = type(node)(rst)
return rst
......@@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_
if isinstance(comm_action.comm_spec, CommSpec):
rst = comm_action.comm_spec.covert_spec_to_action(tensor)
else:
origin_sharding_spec = comm_action.comm_spec['src_spec']
tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
origin_sharding_spec = comm_action.comm_spec["src_spec"]
tgt_sharding_spec = comm_action.comm_spec["tgt_spec"]
rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
return rst
......@@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]):
node_to_index_dict = {}
index = 0
for node in nodes:
if node.target == 'sharding_spec_convert_dict':
if node.target == "sharding_spec_convert_dict":
input_dict_node = node
continue
if node.target == 'origin_node_sharding_spec_dict':
if node.target == "origin_node_sharding_spec_dict":
origin_dict_node = node
continue
if node.target == 'comm_actions_dict':
if node.target == "comm_actions_dict":
comm_actions_dict_node = node
continue
if not hasattr(node, 'best_strategy'):
if not hasattr(node, "best_strategy"):
continue
node_to_index_dict[node] = index
index += 1
......@@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes:
if not hasattr(node, 'best_strategy') or node.op == 'output':
if not hasattr(node, "best_strategy") or node.op == "output":
continue
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
if isinstance(node.sharding_spec, (list, tuple)):
assert isinstance(
node.target_sharding_specs,
(list,
tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
node.target_sharding_specs, (list, tuple)
), "target sharding specs should be tuple or list when node.sharding_spec is tuple or list"
total_difference = 0
for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
node.target_sharding_specs[user_node_index]):
for sharding_spec, target_sharding_spec in zip(
node.sharding_spec, node.target_sharding_specs[user_node_index]
):
total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
if total_difference == 0:
continue
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
shape_consistency_node = mod_graph.create_node(
"call_function",
runtime_apply_for_iterable_object,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
)
else:
assert isinstance(node.sharding_spec,
ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
assert isinstance(
node.sharding_spec, ShardingSpec
), "node.sharding_spec should be type of ShardingSpec, tuple or list."
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
continue
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
shape_consistency_node = mod_graph.create_node(
"call_function",
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
if hasattr(user_node.meta['info'], 'activation_checkpoint'):
MetaInfo(shape_consistency_node,
mod_dir=user_node.meta['info'].mod_dir,
activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))
args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
)
if hasattr(user_node.meta["info"], "activation_checkpoint"):
MetaInfo(
shape_consistency_node,
mod_dir=user_node.meta["info"].mod_dir,
activation_checkpoint=tuple(user_node.meta["info"].activation_checkpoint),
)
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)
# the origin node may be a positional argument or key word argument of user node
......@@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
_, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes:
if not hasattr(node, 'best_strategy') or node.op == 'output':
if not hasattr(node, "best_strategy") or node.op == "output":
continue
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
if comm_action.comm_type == CommType.HOOK:
continue
if comm_action.comm_type == CommType.BEFORE:
......@@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
else:
comm_object = node.args[comm_action.arg_index]
with mod_graph.inserting_before(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
comm_spec_apply_node = mod_graph.create_node(
"call_function",
runtime_comm_spec_apply,
args=(comm_object, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
args=(comm_object, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
)
# the origin node may be a positional argument or key word argument of user node
if comm_action.key_for_kwarg is not None:
# substitute the origin node with comm_spec_apply_node
......@@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
comm_spec_apply_node = mod_graph.create_node(
"call_function",
runtime_comm_spec_apply,
args=(node, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
args=(node, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
)
user_list = list(node.users.keys())
for user in user_list:
if user == comm_spec_apply_node:
......@@ -211,15 +212,17 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
if hasattr(node.meta['info'], 'activation_checkpoint'):
MetaInfo(comm_spec_apply_node,
mod_dir=node.meta['info'].mod_dir,
activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
if hasattr(node.meta["info"], "activation_checkpoint"):
MetaInfo(
comm_spec_apply_node,
mod_dir=node.meta["info"].mod_dir,
activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
)
return gm
def _act_annotataion_pass(gm: torch.fx.GraphModule):
def _act_annotation_pass(gm: torch.fx.GraphModule):
"""
This pass is used to add the act annotation to the new inserted nodes.
"""
......@@ -227,21 +230,21 @@ def _act_annotataion_pass(gm: torch.fx.GraphModule):
nodes = tuple(mod_graph.nodes)
for node in nodes:
if not hasattr(node.meta, 'activation_checkpoint'):
from .runtime_preparation_pass import size_processing
if not hasattr(node.meta, "activation_checkpoint"):
pass
user_act_annotation = -1
input_act_annotation = -1
for user_node in node.users.keys():
if 'activation_checkpoint' in user_node.meta:
user_act_annotation = user_node.meta['activation_checkpoint']
if "activation_checkpoint" in user_node.meta:
user_act_annotation = user_node.meta["activation_checkpoint"]
break
for input_node in node._input_nodes.keys():
if 'activation_checkpoint' in input_node.meta:
input_act_annotation = input_node.meta['activation_checkpoint']
if "activation_checkpoint" in input_node.meta:
input_act_annotation = input_node.meta["activation_checkpoint"]
break
if user_act_annotation == input_act_annotation and user_act_annotation != -1:
node.meta['activation_checkpoint'] = user_act_annotation
node.meta["activation_checkpoint"] = user_act_annotation
return gm
......
import operator
from copy import deepcopy
from typing import Dict, List, Union
import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationDataType,
ShardingStrategy,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import _all_reduce
......@@ -25,11 +18,13 @@ from .constants import SHAPE_ARGUMENT_OPS
shape_consistency_manager = ShapeConsistencyManager()
def size_processing(size: Union[int, torch.Size],
def size_processing(
size: Union[int, torch.Size],
dim_partition_dict: Dict[int, List[int]],
device_mesh_info: Dict[int, int],
target_dim: int = None,
node_name: str = None):
node_name: str = None,
):
"""
This method will be invoked during runtime to convert size node value depending on distributed information.
"""
......@@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size],
return size
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
strategies_constructor: StrategiesConstructor):
def solution_annotation_pass(
gm: torch.fx.GraphModule, solution: List[int], strategies_constructor: StrategiesConstructor
):
"""
This method is used to stick the solution strategy to the nodes and add the information
required in runtime into graph as placeholder nodes.
......@@ -70,14 +66,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector
# stick the solution strategy to the corresponding node
setattr(node, 'best_strategy', strategies_vector[strategy_index])
setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
setattr(node, "best_strategy", strategies_vector[strategy_index])
setattr(node, "sharding_spec", strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
str(node))
str(node)
)
# attach the corresponding metainfo if node has the attribute `strategies_info`
if hasattr(node, 'strategies_info'):
setattr(node, 'best_strategy_info', node.strategies_info[strategy_index])
if hasattr(node, "strategies_info"):
setattr(node, "best_strategy_info", node.strategies_info[strategy_index])
# the dict to get input sharding specs of user node
sharding_spec_convert_dict = {}
......@@ -92,15 +89,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
setattr(node, 'target_sharding_specs', target_sharding_specs)
setattr(node, "target_sharding_specs", target_sharding_specs)
# the get_attr node strategy is kind of pending strategy, which means we will change it
# to the same strategy of the user node.
if node.op == 'get_attr':
assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
if node.op == "get_attr":
assert len(target_sharding_specs) == 1, f"sharing weight is not supported in current version."
target_node = node.strategies_vector.successor_nodes[0]
node_name = str(node)
if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
if target_node.op == "call_function" and target_node.target in RESHAPE_FUNC_OP:
node_name = str(target_node)
target_node = target_node.strategies_vector.successor_nodes[0]
user_strategy = target_node.best_strategy
......@@ -122,11 +119,11 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
# add above dicts into graph
for node in nodes:
if node.op != 'placeholder':
if node.op != "placeholder":
with mod_graph.inserting_before(node):
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
comm_actions_dict_node = mod_graph.create_node("placeholder", target="comm_actions_dict")
break
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
......@@ -144,11 +141,11 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
# DeviceMesh information instructs the scaling of the size value
device_mesh_info = {}
for dim, dim_size in enumerate(device_mesh.mesh_shape):
for dim, dim_size in enumerate(device_mesh.shape):
device_mesh_info[dim] = dim_size
def _extract_target_dim(node):
'''
"""
A helper function to extract the target dimension from size node.
There are two usages of torch.Tensor.size:
1. tensor.size()
......@@ -156,7 +153,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
If a target_dim is assigned, then the output will be in type of int, instead of torch.Size.
Otherwise, the output will be in type of torch.Size and this function will return None.
'''
"""
target_dim = None
if len(node.args) > 1:
target_dim = node.args[1]
......@@ -165,19 +162,21 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
return target_dim
def _post_processing(node, size_processing_node):
'''
"""
This function is used to process the dependency between the size node and its users after
inserting the size_process_node.
'''
# store original node and processing node pair in node_pairs dictioanry
"""
# store original node and processing node pair in node_pairs dictionary
# It will be used to replace the original node with processing node in slice object
node_pairs[node] = size_processing_node
size_processing_node._meta_data = node._meta_data
if hasattr(node.meta['info'], 'activation_checkpoint'):
MetaInfo(size_processing_node,
mod_dir=node.meta['info'].mod_dir,
activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
if hasattr(node.meta["info"], "activation_checkpoint"):
MetaInfo(
size_processing_node,
mod_dir=node.meta["info"].mod_dir,
activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
)
user_list = list(node.users.keys())
for user in user_list:
......@@ -196,10 +195,10 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
user.kwargs = new_kwargs
def _update_slice_object_args(slice_object):
'''
"""
This function is used to update the slice object argument list.
If the slice object contains the Node argument, then the size node will be replaced with
'''
"""
if isinstance(slice_object, slice):
start = slice_object.start
stop = slice_object.stop
......@@ -220,8 +219,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")
for node in nodes:
if node.op == 'call_method' and node.target == 'size':
if node.op == "call_method" and node.target == "size":
# extract useful information from size node
# dim_partition_dict will instruct the size value on which
# dimension should be enlarged.
......@@ -232,14 +230,14 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
# insert size_processing node
with mod_graph.inserting_after(node):
size_processing_node = mod_graph.create_node('call_function',
size_processing_node = mod_graph.create_node(
"call_function",
size_processing,
args=(node, dim_partition_dict, device_mesh_info,
target_dim, node.name))
args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name),
)
_post_processing(node, size_processing_node)
if node.op == 'call_function' and node.target == operator.getitem:
if node.op == "call_function" and node.target == operator.getitem:
getitem_index = node.args[1]
# slice object is quite special in torch.fx graph,
# On one side, we treat slice object same as type of int,
......@@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
nodes = tuple(mod_graph.nodes)
def _extract_info_from_sharding_spec(sharding_spec):
'''
"""
This function is used to extract the dim_partition_dict and device_mesh from
sharding spec instance or a list of sharding spec.
'''
"""
if isinstance(sharding_spec, ShardingSpec):
dim_partition_dict = sharding_spec.dim_partition_dict
device_mesh = sharding_spec.device_mesh
return dim_partition_dict, device_mesh
if sharding_spec is None:
return None, None
assert isinstance(sharding_spec,
(tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
assert isinstance(
sharding_spec, (tuple, list)
), "sharding_spec should be type of ShardingSpec, tuple, list or None"
device_mesh = sharding_spec[0].device_mesh
dim_partition_dict = []
......@@ -322,8 +321,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
else:
new_args.append(arg)
else:
assert isinstance(arg,
(int, tuple, list)), 'The argument in view node should be either type of Node or int.'
assert isinstance(
arg, (int, tuple, list)
), "The argument in view node should be either type of Node or int."
if isinstance(arg, (tuple, list)):
new_args.extend(arg)
else:
......@@ -332,7 +332,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
new_args = _process_node_arguments(node)
if node.op == 'call_method':
if node.op == "call_method":
args_to_process = list(new_args[1:])
else:
args_to_process = list(new_args)
......@@ -350,7 +350,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
args_to_process = tuple(args_to_process)
if node.op == 'call_method':
if node.op == "call_method":
new_args = (new_args[0],) + args_to_process
else:
new_args = args_to_process
......@@ -358,9 +358,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
node.args = new_args
def _filter_node_with_shape_args(node):
if node.op == 'call_method':
if node.op == "call_method":
target = getattr(node.args[0]._meta_data.__class__, node.target)
elif node.op == 'call_function':
elif node.op == "call_function":
target = node.target
else:
target = None
......@@ -371,7 +371,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
for node in nodes:
# skip the placeholder node added in _solution_annotation pass
if not hasattr(node, 'sharding_spec'):
if not hasattr(node, "sharding_spec"):
continue
output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
......@@ -388,19 +388,25 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
# This stream is created for overlaping the communication and computation.
# This stream is created for overlapping the communication and computation.
reduction_stream = torch.cuda.Stream()
def _add_hook_for_grad_communication(node, param, name=None):
comm_actions = node.best_strategy.communication_actions
def _filter_param_to_hook(node, op_data, comm_action, name):
if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
if (
node.op == "call_module"
and op_data.type == OperationDataType.PARAM
and op_data.name == name
and comm_action.comm_type == CommType.HOOK
):
return True
if node.op == 'get_attr' and isinstance(
node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
if (
node.op == "get_attr"
and isinstance(node._meta_data, torch.nn.parameter.Parameter)
and comm_action.comm_type == CommType.HOOK
):
return True
return False
......@@ -410,7 +416,6 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
if _filter_param_to_hook(node, operation_data, comm_action, name=name):
def wrapper(param, comm_spec, stream, overlap):
def hook_fn(grad):
if overlap:
with torch.cuda.stream(stream):
......@@ -426,22 +431,26 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
# apply the sharding spec of parameters
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
setattr(param, "sharding_spec", origin_sharding_spec)
# TODO: build a ColoParameter class to manager the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
param = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
target_sharding_spec).detach().clone())
shape_consistency_manager.apply_for_autoparallel_runtime(
param.data, param.sharding_spec, target_sharding_spec
)
.detach()
.clone()
)
return param
for node in nodes:
if node.op == 'call_module':
if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target)
# TODO: we need to do more actions to take care of the shared parameters.
if hasattr(target_module, 'processed') and target_module.processed:
if hasattr(target_module, "processed") and target_module.processed:
continue
setattr(target_module, 'processed', True)
setattr(target_module, "processed", True)
for name, param in target_module.named_parameters():
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
param = _shard_param(param, target_sharding_spec)
......@@ -453,7 +462,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
# apply the sharding spec of buffers
for name, buffer in target_module.named_buffers():
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
setattr(buffer, 'sharding_spec', origin_sharding_spec)
setattr(buffer, "sharding_spec", origin_sharding_spec)
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
sharded_buffer_dict[name] = buffer_sharded
......@@ -461,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
for name, buffer_sharded in sharded_buffer_dict.items():
setattr(target_module, name, buffer_sharded.detach().clone())
if node.op == 'get_attr':
if node.op == "get_attr":
root = node.graph.owning_module
atoms = node.target.split(".")
attr_len = len(atoms)
......@@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
"""
replace the origin kernel into kernel with implicit communication inside.
"""
pass
def runtime_preparation_pass(gm: torch.fx.GraphModule,
def runtime_preparation_pass(
gm: torch.fx.GraphModule,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor,
overlap=False):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
gm, solution, strategies_constructor)
overlap=False,
):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
gm, solution, strategies_constructor
)
gm = size_value_converting_pass(gm, device_mesh)
gm = node_args_converting_pass(gm, device_mesh)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment