Commit e532679c authored by oahzxl's avatar oahzxl
Browse files

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
from .alpha_beta_profiler import AlphaBetaProfiler
from .calc_pipeline_strategy import alpa_dp
__all__ = ['AlphaBetaProfiler', 'alpa_dp']
import math
import time
from typing import Dict, List, Tuple
import torch
import torch.distributed as dist
from colossalai.logging import get_dist_logger
GB = int((1 << 30))
BYTE = 4
FRAMEWORK_LATENCY = 0
class AlphaBetaProfiler:
'''
Profile alpha and beta value for a given device list.
Usage:
# Note: the environment of execution is supposed to be
# multi-process with multi-gpu in mpi style.
>>> physical_devices = [0, 1, 4, 5]
>>> ab_profiler = AlphaBetaProfiler(physical_devices)
>>> ab_dict = profiler.alpha_beta_dict
>>> print(ab_dict)
{(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11),
(1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
(1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
(4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
'''
def __init__(self,
physical_devices: List[int],
alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
ctype: str = 'a',
warmup: int = 5,
repeat: int = 25,
latency_iters: int = 5,
homogeneous_tolerance: float = 0.1):
'''
Args:
physical_devices: A list of device id, each element inside it is the global rank of that device.
alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
ctype: 'a' for all-reduce, 'b' for broadcast.
warmup: Number of warmup iterations.
repeat: Number of iterations to measure.
latency_iters: Number of iterations to measure latency.
'''
self.physical_devices = physical_devices
self.ctype = ctype
self.world_size = len(physical_devices)
self.warmup = warmup
self.repeat = repeat
self.latency_iters = latency_iters
self.homogeneous_tolerance = homogeneous_tolerance
self.process_group_dict = None
self._init_profiling()
if alpha_beta_dict is None:
self.alpha_beta_dict = self.profile_ab()
else:
self.alpha_beta_dict = alpha_beta_dict
def _init_profiling(self):
# Create process group list based on its global rank
process_group_list = []
for f_index in range(self.world_size - 1):
for b_index in range(f_index + 1, self.world_size):
process_group_list.append((self.physical_devices[f_index], self.physical_devices[b_index]))
# Create process group dict which maps process group to its handler
process_group_dict = {}
for process_group in process_group_list:
pg_handler = dist.new_group(process_group)
process_group_dict[process_group] = pg_handler
self.process_group_dict = process_group_dict
def _profile(self, process_group, pg_handler, nbytes):
logger = get_dist_logger()
rank = dist.get_rank()
src_device_num = process_group[0]
world_size = len(process_group)
device = torch.cuda.current_device()
buf = torch.randn(nbytes // 4).to(device)
torch.cuda.synchronize()
# warmup
for _ in range(self.warmup):
if self.ctype == "a":
dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
elif self.ctype == "b":
dist.broadcast(buf, src=src_device_num, group=pg_handler)
torch.cuda.synchronize()
dist.barrier(group=pg_handler)
begin = time.perf_counter()
for _ in range(self.repeat):
if self.ctype == "a":
dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
elif self.ctype == "b":
dist.broadcast(buf, src=src_device_num, group=pg_handler)
torch.cuda.synchronize()
end = time.perf_counter()
dist.barrier(group=pg_handler)
if rank == src_device_num:
avg_time_s = (end - begin) / self.repeat - FRAMEWORK_LATENCY
alg_band = nbytes / avg_time_s
if self.ctype == "a":
# convert the bandwidth of all-reduce algorithm to the bandwidth of the hardware.
bus_band = 2 * (world_size - 1) / world_size * alg_band
bus_band = alg_band
elif self.ctype == "b":
bus_band = alg_band
logger.info(
f"GPU:{rank}, Bytes: {nbytes} B,Time: {round(avg_time_s * 1e6,2)} us, Bus bandwidth: {round(bus_band / GB,2)} GB/s"
)
return (avg_time_s, alg_band)
else:
# Just a placeholder
return (None, None)
def profile_latency(self, process_group, pg_handler):
'''
This function is used to profile the latency of the given process group with a series of bytes.
Args:
process_group: A tuple of global rank of the process group.
pg_handler: The handler of the process group.
Returns:
latency: None if the latency is not measured, otherwise the median of the latency_list.
'''
latency_list = []
for i in range(self.latency_iters):
nbytes = int(BYTE << i)
(t, _) = self._profile(process_group, pg_handler, nbytes)
latency_list.append(t)
if latency_list[0] is None:
latency = None
else:
median_index = math.floor(self.latency_iters / 2)
latency = latency_list[median_index]
return latency
def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
'''
This function is used to profile the bandwidth of the given process group.
Args:
process_group: A tuple of global rank of the process group.
pg_handler: The handler of the process group.
'''
(_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
return bandwidth
def profile_ab(self):
'''
This method is used to profiling the alpha and beta value for a given device list.
Returns:
alpha_beta_dict: A dict which maps process group to its alpha and beta value.
'''
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
rank = dist.get_rank()
global_pg_handler = dist.new_group(self.physical_devices)
def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
assert rank in process_group
device = torch.cuda.current_device()
rank_max_nbytes = torch.cuda.mem_get_info(device)[0]
rank_max_nbytes = torch.tensor(rank_max_nbytes, device=device)
dist.all_reduce(rank_max_nbytes, op=dist.ReduceOp.MIN, group=pg_handler)
max_nbytes = min(int(1 * GB), int(GB << int(math.log2(rank_max_nbytes.item() / GB))))
return max_nbytes
for process_group, pg_handler in self.process_group_dict.items():
if rank not in process_group:
max_nbytes = None
alpha = None
bandwidth = None
else:
max_nbytes = get_max_nbytes(process_group, pg_handler)
alpha = self.profile_latency(process_group, pg_handler)
bandwidth = self.profile_bandwidth(process_group, pg_handler, maxbytes=max_nbytes)
if bandwidth is None:
beta = None
else:
beta = 1 / bandwidth
broadcast_list = [alpha, beta]
dist.broadcast_object_list(broadcast_list, src=process_group[0])
alpha_beta_dict[process_group] = tuple(broadcast_list)
# add symmetry pair to the apha_beta_dict
symmetry_ab_dict = {}
for process_group, alpha_beta_pair in alpha_beta_dict.items():
symmetry_process_group = (process_group[1], process_group[0])
symmetry_ab_dict[symmetry_process_group] = alpha_beta_pair
alpha_beta_dict.update(symmetry_ab_dict)
return alpha_beta_dict
def search_best_logical_mesh(self):
'''
This method is used to search the best logical mesh for the given device list.
The best logical mesh is searched in following steps:
1. detect homogeneous device groups, we assume that the devices in the alpha_beta_dict
are homogeneous if the beta value is close enough.
2. Find the best homogeneous device group contains all the physical devices. The best homogeneous
device group means the lowest beta value in the groups which contains all the physical devices.
And the reason we require the group contains all the physical devices is that the devices not in
the group will decrease the bandwidth of the group.
3. If the best homogeneous device group is found, we will construct the largest ring for each device
based on the best homogeneous device group, and the best logical mesh will be the union of all the
rings. Otherwise, the best logical mesh will be the balanced logical mesh, such as shape (2, 2) for
4 devices.
Returns:
best_logical_mesh: The best logical mesh for the given device list.
Usage:
>>> physical_devices = [0, 1, 2, 3]
>>> ab_profiler = AlphaBetaProfiler(physical_devices)
>>> best_logical_mesh = profiler.search_best_logical_mesh()
>>> print(best_logical_mesh)
[[0, 1], [2, 3]]
'''
def _power_of_two(integer):
return integer & (integer - 1) == 0
def _detect_homogeneous_device(alpha_beta_dict):
'''
This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.
Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value
of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
* base_beta.
'''
homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
for process_group, (_, beta) in alpha_beta_dict.items():
if homogeneous_device_dict is None:
homogeneous_device_dict[beta] = []
homogeneous_device_dict[beta].append(process_group)
match_beta = None
for beta_value in homogeneous_device_dict.keys():
if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
1 - self.homogeneous_tolerance):
match_beta = beta_value
break
if match_beta is not None:
homogeneous_device_dict[match_beta].append(process_group)
else:
homogeneous_device_dict[beta] = []
homogeneous_device_dict[beta].append(process_group)
return homogeneous_device_dict
def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
'''
This function is used to check whether the homogeneous_group contains all physical devices.
'''
flatten_mesh = []
for process_group in homogeneous_group:
flatten_mesh.extend(process_group)
non_duplicated_flatten_mesh = set(flatten_mesh)
return len(non_duplicated_flatten_mesh) == len(self.physical_devices)
def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
'''
This function is used to construct the largest ring in the homogeneous_group for each rank.
'''
# Construct the ring
ring = []
ranks_in_ring = []
for rank in self.physical_devices:
if rank in ranks_in_ring:
continue
stable_status = False
ring_for_rank = []
ring_for_rank.append(rank)
check_rank_list = [rank]
rank_to_check_list = []
while not stable_status:
stable_status = True
check_rank_list.extend(rank_to_check_list)
rank_to_check_list = []
for i in range(len(check_rank_list)):
check_rank = check_rank_list.pop()
for process_group in homogeneous_group:
if check_rank in process_group:
rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
if rank_to_append not in ring_for_rank:
stable_status = False
rank_to_check_list.append(rank_to_append)
ring_for_rank.append(rank_to_append)
ring.append(ring_for_rank)
ranks_in_ring.extend(ring_for_rank)
return ring
assert _power_of_two(self.world_size)
power_of_two = int(math.log2(self.world_size))
median = power_of_two // 2
balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))
row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
balanced_logical_mesh = []
for row_index in range(row_size):
balanced_logical_mesh.append([])
for column_index in range(column_size):
balanced_logical_mesh[row_index].append(self.physical_devices[row_index * column_size + column_index])
homogeneous_device_dict = _detect_homogeneous_device(self.alpha_beta_dict)
beta_list = [b for b in homogeneous_device_dict.keys()]
beta_list.sort()
beta_list.reverse()
homogeneous_types = len(beta_list)
best_logical_mesh = None
if homogeneous_types >= 2:
for _ in range(homogeneous_types - 1):
lowest_beta = beta_list.pop()
best_homogeneous_group = homogeneous_device_dict[lowest_beta]
# if the best homogeneous group contains all physical devices,
# we will build the logical device mesh based on it. Otherwise,
# we will check next level homogeneous group.
if _check_contain_all_devices(best_homogeneous_group):
# We choose the largest ring for each rank to maximum the best bus utilization.
best_logical_mesh = _construct_largest_ring(best_homogeneous_group)
break
if homogeneous_types == 1 or best_logical_mesh is None:
# in this case, we use balanced logical mesh as the best
# logical mesh.
best_logical_mesh = balanced_logical_mesh
return best_logical_mesh
def extract_alpha_beta_for_device_mesh(self):
'''
Extract the mesh_alpha list and mesh_beta list based on the
best logical mesh, which will be used to initialize the device mesh.
Usage:
>>> physical_devices = [0, 1, 2, 3]
>>> ab_profiler = AlphaBetaProfiler(physical_devices)
>>> mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh()
>>> print(mesh_alpha)
[2.5917552411556242e-05, 0.00010312341153621673]
>>> print(mesh_beta)
[5.875573704655635e-11, 4.7361584445959614e-12]
'''
best_logical_mesh = self.search_best_logical_mesh()
first_axis = [row[0] for row in best_logical_mesh]
second_axis = best_logical_mesh[0]
# init process group for both axes
first_axis_process_group = dist.new_group(first_axis)
second_axis_process_group = dist.new_group(second_axis)
# extract alpha and beta for both axes
def _extract_alpha_beta(pg, pg_handler):
latency = self.profile_latency(pg, pg_handler)
bandwidth = self.profile_bandwidth(pg, pg_handler)
broadcast_object = [latency, bandwidth]
dist.broadcast_object_list(broadcast_object, src=pg[0])
return broadcast_object
first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
mesh_alpha = [first_latency, second_latency]
mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth]
return mesh_alpha, mesh_beta
from math import pow
import numpy as np
def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
submesh_choices = []
i = 1
p = -1
while i <= num_devices_per_host:
i *= 2
p += 1
assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, "
f"while now num_devices_per_host = {num_devices_per_host}")
if mode == "alpa":
for i in range(p + 1):
submesh_choices.append((1, pow(2, i)))
for i in range(2, num_hosts + 1):
submesh_choices.append((i, num_devices_per_host))
elif mode == "new":
for i in range(p // 2 + 1):
for j in range(i, p - i + 1):
submesh_choices.append((pow(2, i), pow(2, j)))
return submesh_choices
def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
best_configs):
"""Implementation of Alpa DP for pipeline strategy
Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
Arguments:
num_layers: K
num_devices: N*M
num_microbatches: B
submesh_choices: List[(n_i,m_i)]
compute_cost: t_intra
"""
# For f, layer ID start from 0
# f[#pipeline stages, layer id that is currently being considered, number of devices used]
f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1), 0.0, dtype=np.float32)
f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3), -1, dtype=np.int32)
f[0, num_layers, 0] = 0
for s in range(1, num_layers + 1):
for k in range(num_layers - 1, -1, -1):
for d in range(1, num_devices + 1):
for m, submesh in enumerate(submesh_choices):
n_submesh_devices = np.prod(np.array(submesh))
if n_submesh_devices <= d:
# TODO: [luzgh]: Why alpa needs max_n_succ_stages? Delete.
# if s - 1 <= max_n_succ_stages[i, k - 1, m, n_config]:
# ...
for i in range(num_layers, k, -1):
stage_cost = compute_cost[k, i, m]
new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost
if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]):
f[s, k, d] = new_cost
f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
f_argmin[s, k, d] = (i, m, best_configs[k, i, m])
best_s = -1
best_total_cost = np.inf
for s in range(1, num_layers + 1):
if f[s, 0, num_devices] < best_total_cost:
best_s = s
best_total_cost = f[s, 0, num_devices]
if np.isinf(best_total_cost):
return np.inf, None
total_cost = f[best_s, 0, num_devices] + (num_microbatches - 1) * f_stage_max[best_s, 0, num_devices]
current_s = best_s
current_layer = 0
current_devices = num_devices
res = []
while current_s > 0 and current_layer < num_layers and current_devices > 0:
next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
assert next_start_layer != -1 and current_devices != -1
res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
current_s -= 1
current_layer = next_start_layer
current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
assert (current_s == 0 and current_layer == num_layers and current_devices == 0)
return total_cost, res
def alpa_dp(num_layers,
num_devices,
num_microbatches,
submesh_choices,
num_autosharding_configs,
compute_cost,
gap=1e-6):
"""Alpa auto stage dynamic programming.
Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
Arguments:
submesh_choices: List[(int,int)]
num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs)
"""
assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices),
num_autosharding_configs), "Cost shape wrong."
all_possible_stage_costs = np.sort(np.unique(compute_cost))
best_cost = np.inf
best_solution = None
last_max_stage_cost = 0.0
# TODO: [luzgh]: Why alpa needs the num_autosharding_configs dimension in compute_cost?
# In dp_impl it seems the argmin n_config will be chosen. Just amin here.
best_configs = np.argmin(compute_cost, axis=3)
best_compute_cost = np.amin(compute_cost, axis=3)
assert len(all_possible_stage_costs), "no solution in auto stage construction."
for max_stage_cost in all_possible_stage_costs:
if max_stage_cost * num_microbatches >= best_cost:
break
if max_stage_cost - last_max_stage_cost < gap:
continue
cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost,
max_stage_cost, best_configs)
if cost < best_cost:
best_cost = cost
best_solution = solution
last_max_stage_cost = max_stage_cost
return best_cost, best_solution
from functools import reduce
import operator
from functools import reduce
import torch
import torch.distributed as dist
......@@ -11,7 +12,7 @@ class DeviceMesh:
can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its
own latency and bandwidth. We use alpha-beta model to model the
communication cost.
Arguments:
physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
mesh_shape (torch.Size): shape of logical view.
......@@ -23,6 +24,7 @@ class DeviceMesh:
during initializing the DeviceMesh instance if the init_process_group set to True.
Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group.
(default: False)
need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True.
"""
def __init__(self,
......@@ -49,8 +51,11 @@ class DeviceMesh:
self.need_flatten = need_flatten
if self.init_process_group:
self.process_groups_dict = self.create_process_groups_for_logical_mesh()
if self.need_flatten:
if self.need_flatten and self._logical_mesh_id.dim() > 1:
self.flatten_device_mesh = self.flatten()
# Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
self.mesh_beta)
@property
def shape(self):
......@@ -64,6 +69,18 @@ class DeviceMesh:
def logical_mesh_id(self):
return self._logical_mesh_id
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k != 'process_groups_dict':
setattr(result, k, __import__("copy").deepcopy(v, memo))
else:
setattr(result, k, v)
return result
def flatten(self):
"""
Flatten the logical mesh into an effective 1d logical mesh,
......@@ -90,7 +107,7 @@ class DeviceMesh:
def create_process_groups_for_logical_mesh(self):
'''
This method is used to initialize the logical process groups which will be used in communications
among logical device mesh.
among logical device mesh.
Note: if init_process_group set to False, you have to call this method manually. Otherwise,
the communication related function, such as ShapeConsistencyManager.apply will raise errors.
'''
......@@ -186,3 +203,38 @@ class DeviceMesh:
penalty_factor = num_devices / 2.0
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
(num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
class FlattenDeviceMesh(DeviceMesh):
def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
super().__init__(physical_mesh_id,
mesh_shape,
mesh_alpha,
mesh_beta,
init_process_group=False,
need_flatten=False)
# Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars
self.mesh_alpha = max(self.mesh_alpha)
self.mesh_beta = min(self.mesh_beta)
# Different from original process_groups_dict, rank_list is not stored
self.process_number_dict = self.create_process_numbers_for_logical_mesh()
def create_process_numbers_for_logical_mesh(self):
'''
Build 1d DeviceMesh in column-major(0) and row-major(1)
for example:
mesh_shape = (2,4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7]]
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
'''
num_devices = reduce(operator.mul, self.mesh_shape, 1)
process_numbers_dict = {}
process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist()
process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist()
return process_numbers_dict
def mix_gather_cost(self, num_bytes):
num_devices = reduce(operator.mul, self.mesh_shape, 1)
return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1)
from ._compatibility import compatibility, is_compatible_with_meta
from .graph_module import ColoGraphModule
from .passes import MetaInfoProp
from .tracer import ColoTracer, meta_trace
from ._compatibility import compatibility, is_compatible_with_meta
from .graph_module import ColoGraphModule
from .passes import MetaInfoProp, metainfo_trace
from .tracer import ColoTracer, meta_trace, symbolic_trace
......@@ -3,7 +3,7 @@
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch.utils._pytree import tree_map
......@@ -163,6 +163,23 @@ def meta_conv(
return out
@register_meta(aten._convolution.default)
def meta_conv_1(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: List[int],
padding: List[int],
dilation: List[int],
is_transposed: bool,
output_padding: List[int],
groups: int,
*extra_args
):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@register_meta(aten.convolution_backward.default)
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
padding, dilation, transposed, output_padding, groups, output_mask):
......@@ -179,6 +196,79 @@ def meta_adaptive_avg_pool2d_backward(
return grad_input
# ================================ RNN =============================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn.default)
def meta_cuda_rnn(
input,
weight,
weight_stride0,
weight_buf,
hx,
cx,
mode,
hidden_size,
proj_size,
num_layers,
batch_first,
dropout,
train,
bidirectional,
batch_sizes,
dropout_state,
):
is_input_packed = len(batch_sizes) != 0
if is_input_packed:
seq_length = len(batch_sizes)
mini_batch = batch_sizes[0]
batch_sizes_sum = input.shape[0]
else:
seq_length = input.shape[1] if batch_first else input.shape[0]
mini_batch = input.shape[0] if batch_first else input.shape[1]
batch_sizes_sum = -1
num_directions = 2 if bidirectional else 1
out_size = proj_size if proj_size != 0 else hidden_size
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
out_shape = (
[mini_batch, seq_length, out_size * num_directions]
if batch_first
else [seq_length, mini_batch, out_size * num_directions]
)
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
cy = torch.empty(0) if cx is None else cx.new_empty(cell_shape)
hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
# TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
reserve_shape = 0 if train else 0
reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
return output, hy, cy, reserve, weight_buf
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default)
def meta_cudnn_rnn_backward(input: torch.Tensor,
weight: torch.Tensor,
weight_stride0: int,
hx: torch.Tensor,
cx: Optional[torch.Tensor] = None,
*args,
**kwargs):
print(input, weight, hx, cx)
grad_input = torch.empty_like(input)
grad_weight = torch.empty_like(weight)
grad_hx = torch.empty_like(hx)
grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta')
return grad_input, grad_weight, grad_hx, grad_cx
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
# ============================== Activations =======================================
@register_meta(aten.relu.default)
......@@ -186,6 +276,11 @@ def meta_relu(input: torch.Tensor):
return torch.empty_like(input)
@register_meta(aten.prelu.default)
def meta_prelu(input: torch.Tensor, weight: torch.Tensor):
return torch.empty_like(input)
@register_meta(aten.hardswish.default)
def meta_hardswish(input: torch.Tensor):
return torch.empty_like(input)
......@@ -278,12 +373,18 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
# ================================== Misc ==========================================
#https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
return input
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp
@register_meta(aten._local_scalar_dense.default)
def meta_local_scalar_dense(self: torch.Tensor):
return 0
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
......@@ -317,7 +418,7 @@ def meta_index_Tensor(self, indices):
indices = result
assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
# expand_outplace
import torch._refs as refs # avoid import cycle in mypy
import torch._refs as refs
indices = list(refs._maybe_broadcast(*indices))
# add missing null tensors
......
import colossalai
from typing import Any, Callable, Dict, Iterable, List, Tuple
import torch
from typing import List, Callable, Any, Tuple, Dict, Iterable
import colossalai
try:
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
_CustomBuiltin,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
inplace_methods,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
CODEGEN_AVAILABLE = True
except:
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import (
PythonCode,
_custom_builtins,
_CustomBuiltin,
_format_args,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
CODEGEN_AVAILABLE = False
if CODEGEN_AVAILABLE:
......@@ -27,7 +50,7 @@ def _gen_saved_tensors_hooks():
return (x.device, x.cpu())
else:
return x
def pack_hook_no_input(self, x):
if getattr(x, "offload", True):
return (x.device, x.cpu())
......@@ -48,11 +71,9 @@ def pack_hook_no_input(self, x):
def _gen_save_tensors_hooks_context(offload_input=True) -> str:
"""Generate customized saved_tensors_hooks
Args:
offload_input (bool, optional): whether we need offload input, if offload_input=False,
offload_input (bool, optional): whether we need offload input, if offload_input=False,
we will use self.pack_hook_no_input instead. Defaults to True.
Returns:
str: generated context
"""
......@@ -111,8 +132,8 @@ def _find_ckpt_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_checkpoint'):
act_ckpt_label = node.activation_checkpoint
if 'activation_checkpoint' in node.meta:
act_ckpt_label = node.meta['activation_checkpoint']
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
......@@ -129,7 +150,7 @@ def _find_ckpt_regions(nodes: List[Node]):
current_region = act_ckpt_label
start = idx
end = -1
elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
elif current_region is not None and not 'activation_checkpoint' in node.meta:
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
......@@ -144,7 +165,7 @@ def _find_ckpt_regions(nodes: List[Node]):
def _find_offload_regions(nodes: List[Node]):
"""This function is to find the offload regions
In pofo algorithm, during annotation, we will annotate the offload region with the
In pofo algorithm, during annotation, we will annotate the offload region with the
list in the form of [idx, offload_input, offload_bar]. idx indicates the offload
region's index, offload_input is a bool type indicates whether we need to offload
the input, offload_bar is a bool type indicates whether we need to offload all the
......@@ -157,8 +178,8 @@ def _find_offload_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable):
act_offload_label = node.activation_offload
if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable):
act_offload_label = node.meta['activation_offload']
if current_region == None:
current_region = act_offload_label
......@@ -212,18 +233,16 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
def _end_of_ckpt(node: Node, check_idx: int) -> bool:
"""Check if the node could end the ckpt region
Args:
node (Node): torch.fx.Node
check_idx (int): the index of checkpoint level for
check_idx (int): the index of checkpoint level for
nested checkpoint
Returns:
bool
"""
if hasattr(node, "activation_checkpoint"):
if isinstance(node.activation_checkpoint, list):
return node.activation_checkpoint[check_idx] == None
if 'activation_checkpoint' in node.meta:
if isinstance(node.meta['activation_checkpoint'], list):
return node.meta['activation_checkpoint'][check_idx] == None
else:
return False
else:
......@@ -232,7 +251,7 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool:
def _find_nested_ckpt_regions(nodes, check_idx=0):
"""
Find the nested checkpoint regions given a list of consecutive nodes. The outputs
Find the nested checkpoint regions given a list of consecutive nodes. The outputs
will be list of tuples, each tuple is in the form of (start_index, end_index).
"""
ckpt_regions = []
......@@ -241,11 +260,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_checkpoint'):
if isinstance(getattr(node, 'activation_checkpoint'), int):
act_ckpt_label = node.activation_checkpoint
if 'activation_checkpoint' in node.meta:
if isinstance(node.meta['activation_checkpoint'], int):
act_ckpt_label = node.meta['activation_checkpoint']
else:
act_ckpt_label = node.activation_checkpoint[check_idx]
act_ckpt_label = node.meta['activation_checkpoint'][check_idx]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
......@@ -287,7 +306,6 @@ def emit_ckpt_func(body,
level=0,
in_ckpt=False):
"""Emit ckpt fuction in nested way
Args:
body: forward code, in recursive calls, this part will be checkpoint
functions code
......@@ -303,8 +321,8 @@ def emit_ckpt_func(body,
inputs, outputs = _find_input_and_output_nodes(node_list)
# if the current checkpoint function use int as label, using old generation method
if isinstance(node_list[0].activation_checkpoint, int):
label = node_list[0].activation_checkpoint
if isinstance(node_list[0].meta['activation_checkpoint'], int):
label = node_list[0].meta['activation_checkpoint']
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
for node in node_list:
......@@ -313,7 +331,7 @@ def emit_ckpt_func(body,
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = getattr(node_list[0], "activation_offload", False)
activation_offload = node_list[0].meta.get('activation_offload', False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
usage += "\n"
body.append(usage)
......@@ -322,12 +340,12 @@ def emit_ckpt_func(body,
else:
# label given by each layer, e.g. if you are currently at level [0, 1, 1]
# the label will be '0_1_1'
label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]])
label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
# if there is more level to fetch
if level + 1 < len(node_list[0].activation_checkpoint):
if level + 1 < len(node_list[0].meta['activation_checkpoint']):
ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
......@@ -354,7 +372,7 @@ def emit_ckpt_func(body,
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
ckpt_func += ckpt_func_buffer
activation_offload = getattr(node_list[0], "activation_offload", False)
activation_offload = node_list[0].meta.get('activation_offload', False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
if in_ckpt:
usage = ' ' + usage
......@@ -368,7 +386,7 @@ def emit_ckpt_func(body,
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = getattr(node_list[0], "activation_offload", False)
activation_offload = node_list[0].meta.get('activation_offload', False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
if in_ckpt:
usage = ' ' + usage
......@@ -379,7 +397,6 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
"""Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use
this function to emit the activation checkpoint codes.
Args:
body: forward code
ckpt_func: checkpoint functions code
......@@ -564,8 +581,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# we need to check if the checkpoint need to offload the input
start_node_idx = start_idx[label]
if hasattr(node_list[start_node_idx], 'activation_offload'):
activation_offload = node_list[start_node_idx].activation_offload
if 'activation_offload' in node_list[start_node_idx].meta:
activation_offload = node_list[start_node_idx].meta['activation_offload']
else:
activation_offload = False
......@@ -577,8 +594,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if input_node.op != "placeholder":
non_leaf_input = 1
for user in input_node.users:
if hasattr(user, "activation_checkpoint"):
if user.activation_checkpoint == label:
if 'activation_checkpoint' in user.meta:
if user.meta['activation_checkpoint'] == label:
if user.op == "call_module":
if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
......@@ -616,10 +633,8 @@ if CODEGEN_AVAILABLE:
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
We call this for names that reference objects external to the
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
......@@ -796,7 +811,7 @@ if CODEGEN_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes):
if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
......@@ -829,7 +844,6 @@ if CODEGEN_AVAILABLE:
code = '\n'.join(' ' + line for line in code.split('\n'))
fn_code = f"""
{wrap_stmts}
{prologue}
{code}"""
return PythonCode(fn_code, globals_)
......@@ -851,10 +865,8 @@ else:
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
We call this for names that reference objects external to the
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
......@@ -999,7 +1011,7 @@ else:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes):
if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
......@@ -1040,7 +1052,6 @@ else:
# in forward function
fn_code = f"""
{wrap_stmts}
{ckpt_func}
def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
{code}"""
......
from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
from .meta_info_prop import MetaInfoProp
from .concrete_info_prop import ConcreteInfoProp
from .meta_info_prop import MetaInfoProp, metainfo_trace
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
......@@ -9,6 +9,30 @@ def pipe_split():
pass
def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
"""
In avgnode_split_pass, simpliy split graph by node number.
"""
mod_graph = gm.graph
avg_num_node = len(mod_graph.nodes) // pp_size
accumulate_num_node = 0
for node in mod_graph.nodes:
if pp_size <= 1:
break
accumulate_num_node += 1
if accumulate_num_node >= avg_num_node:
accumulate_num_node = 0
pp_size -= 1
if node.next.op == 'output':
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
else:
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
gm.recompile()
return gm
def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
"""
In balanced_split_pass, we split module by the size of parameters(weights+bias).
......@@ -37,6 +61,21 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
else:
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
if pp_size > 1:
node_counter = 0
for node in mod_graph.nodes:
if pp_size <= 1:
break
if node.op == 'placeholder':
continue
elif node_counter == 0:
node_counter += 1
else:
pp_size -= 1
node_counter = 0
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
gm.recompile()
return gm
......@@ -102,7 +141,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
return gm
def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output=False):
# TODO(lyl): use partition IR to assign partition ID to each node.
# Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph
# In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node
......@@ -114,7 +153,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
part_idx += 1
return part_idx
split_mod = split_module(annotated_gm, None, split_callback)
split_mod = split_module(annotated_gm, None, split_callback, merge_output)
split_submodules = []
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
......
import math
from typing import List, Set, Tuple
import torch
from torch.fx import GraphModule, Node
import math
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
__all__ = ['chen_greedy']
......
import math
import sys
from typing import List, Tuple
from colossalai.fx.profiler.memory import calculate_fwd_in
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
from colossalai.logging import get_dist_logger
from .linearize import linearize
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
# global vairable to indicate whether the solver is failed
SOLVER_FAILED = False
......@@ -18,7 +20,7 @@ SOLVER_FAILED = False
# https://gitlab.inria.fr/hiepacs/rotor
# paper link: https://hal.inria.fr/hal-02352969
def _compute_table(chain: Chain, mmax) -> Tuple:
"""Returns the optimal table: a tuple containing:
"""Returns the optimal table: a tuple containing:
Opt[m][lmin][lmax] with lmin = 0...chain.length
and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
......@@ -127,7 +129,7 @@ def _fwd_xbar(node: List[Node]) -> int:
"""Get the forward xbar of a node
Args:
node (List[Node]): List of torch.fx Node,
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
......@@ -372,8 +374,8 @@ def solver_rotor(gm: ColoGraphModule,
# build module if module not found
except ModuleNotFoundError:
import subprocess
import os
import subprocess
logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(
......
......@@ -3,11 +3,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple
import torch
import torch.fx
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_flatten
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo, profile_function, profile_method, profile_module
@compatibility(is_backward_compatible=True)
class ConcreteInfoProp(torch.fx.Interpreter):
......@@ -22,17 +23,17 @@ class ConcreteInfoProp(torch.fx.Interpreter):
DIM_HIDDEN = 16
DIM_OUT = 16
model = torch.nn.Sequential(
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
).cuda()
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda")
gm = symbolic_trace(model)
interp = ConcreteInfoProp(gm)
interp.run(input_sample)
print(interp.summary(unit='kb'))
output of above code is
print(interp.summary(unit='kb'))
output of above code is
Op type Op Forward time Backward time SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- ----------------------- ------------------------ ------------- --------- --------- --------- ---------
placeholder input_1 0.0 s 0.0 s False 0.00 KB 0.00 KB 0.00 KB 0.00 KB
......@@ -229,8 +230,8 @@ class ConcreteInfoProp(torch.fx.Interpreter):
def summary(self, unit: str = 'MB') -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
to be installed.
"""
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
......
......@@ -3,12 +3,21 @@ from typing import Any, Dict, List, NamedTuple, Tuple
import torch
import torch.fx
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map
from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
from colossalai.fx.profiler import (
GraphInfo,
activation_size,
calculate_fwd_in,
calculate_fwd_out,
calculate_fwd_tmp,
profile_function,
profile_method,
profile_module,
)
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
......@@ -52,7 +61,7 @@ class MetaInfoProp(torch.fx.Interpreter):
DIM_HIDDEN = 16
DIM_OUT = 16
model = torch.nn.Sequential(
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
)
input_sample = torch.rand(BATCH_SIZE, DIM_IN)
......@@ -60,9 +69,9 @@ class MetaInfoProp(torch.fx.Interpreter):
interp = MetaInfoProp(gm)
interp.run(input_sample)
print(interp.summary(format='kb')) # don't panic if some statistics are 0.00 MB
# output of above code is
# output of above code is
Op type Op Forward FLOPs Backward FLOPs FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- --------------- ---------------- --------- --------- --------- ---------
placeholder input_1 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB
......@@ -248,8 +257,8 @@ class MetaInfoProp(torch.fx.Interpreter):
def summary(self, unit: str = 'MB') -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
to be installed.
"""
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
......@@ -306,3 +315,38 @@ class MetaInfoProp(torch.fx.Interpreter):
]
return tabulate(node_summaries, headers=headers, stralign='right')
def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None:
"""
MetaInfo tracing API
Given a ``GraphModule`` and a sample input, this API will trace the MetaInfo of a single training cycle,
and annotate them on ``gm.graph``.
Uses:
>>> model = ...
>>> gm = symbolic_trace(model)
>>> args = ... # sample input to the ``GraphModule``
>>> metainfo_trace(gm, *args)
Args:
gm (torch.fx.GraphModule): The ``GraphModule`` to be annotated with MetaInfo.
verbose (bool, optional): Whether to show ``MetaInfoProp.summary()`. Defaults to False.
unit (str, optional): The unit of memory. Defaults to "MB".
Returns:
torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
"""
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
interp = MetaInfoProp(gm.to(device))
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
interp.propagate(*args, **kwargs)
if verbose:
interp.summary(unit)
gm.to('cpu')
del interp
return gm
......@@ -38,11 +38,11 @@ def split_module(
m: GraphModule,
root_m: torch.nn.Module,
split_callback: Callable[[torch.fx.node.Node], int],
merge_output = False,
):
"""
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
Creates subgraphs out of main graph
Args:
m (GraphModule): Graph module to split
root_m (torch.nn.Module): root nn module. Not currently used. Included
......@@ -52,52 +52,40 @@ def split_module(
that maps a given Node instance to a numeric partition identifier.
split_module will use this function as the policy for which operations
appear in which partitions in the output Module.
Returns:
GraphModule: the module after split.
Example:
This is a sample setup:
import torch
from torch.fx.symbolic_trace import symbolic_trace
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x, y):
z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
w = self.linear(y).clamp(min=0.0, max=1.0)
return z + w
# symbolically trace model
my_module = MyModule()
my_module_traced = symbolic_trace(my_module)
# random mod partitioning
partition_counter = 0
NPARTITIONS = 3
def mod_partition(node: Node):
global partition_counter
partition = partition_counter % NPARTITIONS
partition_counter = (partition_counter + 1) % NPARTITIONS
return partition
# split module in module with submodules
module_with_submodules = split_module(
my_module_traced, my_module, mod_partition
)
Output looks like this. Original graph is broken into partitions
> print(module_with_submodules)
GraphModule(
(submod_0): GraphModule(
......@@ -108,7 +96,6 @@ def split_module(
)
(submod_2): GraphModule()
)
def forward(self, x, y):
param = self.param
submod_0 = self.submod_0(x, param, y); x = param = y = None
......@@ -119,10 +106,8 @@ def split_module(
getitem_3 = submod_1[1]; submod_1 = None
submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
return submod_2
Output of split module is the same as output of input traced module.
This is an example within a test setting:
> orig_out = my_module_traced(x, y)
> submodules_out = module_with_submodules(x, y)
> self.assertEqual(orig_out, submodules_out)
......@@ -147,6 +132,29 @@ def split_module(
use_partition.inputs.setdefault(def_node.name)
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
def record_output(
def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]
): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
if def_partition_name is not None:
def_partition = partitions[def_partition_name]
def_partition.outputs.setdefault(def_node.name)
if use_partition_name is not None:
def_partition.partition_dependents.setdefault(use_partition_name)
if use_partition_name is not None:
use_partition = partitions[use_partition_name]
use_partition.inputs.setdefault(def_node.name)
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
use_partition.outputs.setdefault(def_node.name)
else:
if use_partition_name is not None:
use_partition = partitions[use_partition_name]
use_partition.outputs.setdefault(def_node.name)
# split nodes into parititons
for node in m.graph.nodes:
......@@ -155,7 +163,10 @@ def split_module(
if node.op in ["placeholder"]:
continue
if node.op == 'output':
torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
if merge_output:
torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
else:
torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
continue
partition_name = str(split_callback(node))
......@@ -235,10 +246,10 @@ def split_module(
for node in m.graph.nodes:
if node.op == 'placeholder':
if version.parse(torch.__version__) < version.parse('1.11.0'):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name].meta = node.meta.copy()
......@@ -278,4 +289,9 @@ def split_module(
if node.op == 'output':
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
for partition_name in sorted_partitions:
partition = partitions[partition_name]
new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
return new_gm
import torch
from typing import Dict, Set
from typing import Dict
from torch.fx.node import Node, map_arg
from torch.fx.graph import Graph
def get_comm_size(prev_partition, next_partition):
"""
Given two partitions (parent and child),
......@@ -32,7 +31,6 @@ def get_comm_size(prev_partition, next_partition):
def get_leaf(graph: Graph):
"""
Given a graph, return leaf nodes of this graph.
Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph,
we will get a normal DAG. Leaf nodes in this context means leaf nodes in that DAG.
"""
......@@ -57,7 +55,6 @@ def is_leaf(graph: Graph, node: Node):
def get_top(graph: Graph):
"""
Given a graph, return top nodes of this graph.
Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph,
we will get a normal DAG. Top nodes in this context means nodes with BFS level 0 in that DAG.
"""
......@@ -100,7 +97,6 @@ def get_all_consumers(graph: Graph, node: Node):
def assign_bfs_level_to_nodes(graph: Graph):
"""
Give a graph, assign bfs level to each node of this graph excluding ``placeholder`` and ``output`` nodes.
Example:
class MLP(torch.nn.Module):
def __init__(self, dim: int):
......@@ -110,8 +106,6 @@ def assign_bfs_level_to_nodes(graph: Graph):
self.linear3 = torch.nn.Linear(dim, dim)
self.linear4 = torch.nn.Linear(dim, dim)
self.linear5 = torch.nn.Linear(dim, dim)
def forward(self, x):
l1 = self.linear1(x)
l2 = self.linear2(x)
......@@ -165,10 +159,8 @@ def assign_bfs_level_to_nodes(graph: Graph):
def get_node_module(node) -> torch.nn.Module:
"""
Find the module associated with the given node.
Args:
node (torch.fx.Node): a torch.fx.Node object in the fx computation graph
Returns:
torch.nn.Module: the module associated with the given node
"""
......@@ -177,3 +169,4 @@ def get_node_module(node) -> torch.nn.Module:
assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
module = node.graph.owning_module.get_submodule(node.target)
return module
from .._compatibility import is_compatible_with_meta
if is_compatible_with_meta():
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .opcount import flop_mapping
from .profiler import profile_function, profile_method, profile_module
from .shard_utils import (
calculate_bwd_time,
calculate_fwd_in,
calculate_fwd_out,
calculate_fwd_time,
calculate_fwd_tmp,
)
from .tensor import MetaTensor
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
from .dataflow import GraphInfo
from .memory import activation_size, is_inplace, parameter_size
from .memory_utils import activation_size, is_inplace, parameter_size
......@@ -6,7 +6,7 @@ from typing import Dict, List
from torch.fx import Graph, Node
from .._compatibility import compatibility
from .memory import activation_size, is_inplace
from .memory_utils import activation_size, is_inplace
class Phase(Enum):
......@@ -29,7 +29,7 @@ class GraphInfo:
placeholders saved for | | \__________ | |
backward. | | \ | |
| [fwd_tmp] ------> [bwd_tmp] | <-----
| | \_________ | | [bwd_tmp] marks the peak memory
| | \_________ | | [bwd_tmp] marks the peak memory
| / \ \ | | in backward pass.
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
in [fwd_tmp] because | | \_____ | |
......@@ -80,18 +80,18 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
Nodes should have attribute `out` indicating the output of each node.
============================================================================
Placeholder ----> p o <---- We need to keep track of grad out
|\________ |
|\________ |
↓ ↘|
f --------> b
|\ \_____ ↑
| \ ↘ /
f f ----> b <---- Not every forward result needs to be saved for backward
| \____ ↑
↘ ↘|
↘ ↘|
f ----> b <---- Backward can be freed as soon as it is required no more.
↘ ↗
l
=============================================================================
=============================================================================
Args:
graph (Graph): The autograd graph with nodes marked for keyword `phase`.
......
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .profiler import profile_function, profile_method, profile_module
from .profiler_function import *
from .profiler_module import *
from .registry import meta_profiler_function, meta_profiler_module
from .shard_utils import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
......@@ -5,7 +5,7 @@ import torch
from torch.fx.node import Argument, Target
from ..._compatibility import compatibility
from ..memory import activation_size
from ..memory_utils import activation_size
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module
......@@ -27,7 +27,7 @@ class GraphInfo:
placeholders saved for | | \__________ | |
backward. | | \ | |
| [fwd_tmp] ------> [bwd_tmp] | <-----
| | \_________ | | [bwd_tmp] marks the peak memory
| | \_________ | | [bwd_tmp] marks the peak memory
| / \ \ | | in backward pass.
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
in [fwd_tmp] because | | | \_____ | |
......@@ -76,14 +76,14 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target') -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
Unfortunately, backward memory cost and FLOPs are estimated results.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn.functional` are available.
Examples:
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
......@@ -142,13 +142,13 @@ def profile_method(target: 'Target') -> Callable:
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module) -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to
Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn` are available.
Example:
>>> input = torch.rand(4, 3, 224, 224, device='meta')
>>> mod = torch.nn.Conv2d(3, 128, 3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment