Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
......@@ -94,7 +94,7 @@ class ProcessGroupMesh:
return np.unravel_index(rank, shape)
@staticmethod
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int:
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = "raise") -> int:
"""Convert a coordinate to a rank.
mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
with wrap, index out of range would be wrapped around.
......@@ -141,8 +141,9 @@ class ProcessGroupMesh:
return list(self._group_to_ranks[group])
@staticmethod
def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int,
indices_at_axis: List[int]) -> List[Tuple[int, ...]]:
def get_coords_along_axis(
base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
) -> List[Tuple[int, ...]]:
"""Get coordinates along the given axis.
Args:
......@@ -155,13 +156,12 @@ class ProcessGroupMesh:
"""
coords_in_group = []
for idx in indices_at_axis:
coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1:])
coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
return coords_in_group
def create_group_along_axis(self,
axis: int,
indices_at_axis: Optional[List[int]] = None,
backend: Optional[str] = None) -> ProcessGroup:
def create_group_along_axis(
self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
) -> ProcessGroup:
"""Create all process groups along the given axis, and return the one which the current process belongs to.
Args:
......@@ -186,10 +186,9 @@ class ProcessGroupMesh:
target_group = group
return target_group
def get_group_along_axis(self,
axis: int,
indices_at_axis: Optional[List[int]] = None,
backend: Optional[str] = None) -> ProcessGroup:
def get_group_along_axis(
self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
) -> ProcessGroup:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
Args:
......
......@@ -3,6 +3,6 @@ from .config import Config, ConfigException
# from .moe_context import MOE_CONTEXT
__all__ = [
'Config',
'ConfigException',
"Config",
"ConfigException",
]
......@@ -5,6 +5,7 @@ import inspect
import sys
from importlib.machinery import SourceFileLoader
from pathlib import Path
from colossalai.logging import get_dist_logger
......@@ -41,7 +42,7 @@ class Config(dict):
self.__setattr__(key, value)
def update(self, config):
assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.'
assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
for k, v in config.items():
self._add_item(k, v)
return self
......@@ -66,11 +67,11 @@ class Config(dict):
elif isinstance(filename, Path):
filepath = filename.absolute()
assert filepath.exists(), f'{filename} is not found, please check your configuration path'
assert filepath.exists(), f"{filename} is not found, please check your configuration path"
# check extension
extension = filepath.suffix
assert extension == '.py', 'only .py files are supported'
assert extension == ".py", "only .py files are supported"
# import the config as module
remove_path = False
......@@ -86,13 +87,13 @@ class Config(dict):
config = Config()
for k, v in module.__dict__.items():
if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v):
if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
continue
else:
config._add_item(k, v)
logger = get_dist_logger()
logger.debug('variables which starts with __, is a module or class declaration are omitted in config file')
logger.debug("variables which starts with __, is a module or class declaration are omitted in config file")
# remove module
del sys.modules[module_name]
......
......@@ -9,14 +9,13 @@ from colossalai.legacy.tensor import ProcessGroup
def _check_sanity():
from colossalai.legacy.core import global_context as gpc
if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
raise NotImplementedError("Moe is not compatible with tensor or "
"pipeline parallel at present.")
raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.")
class MoeParallelInfo:
"""Moe parallelism information, storing parallel sizes and groups.
"""
"""Moe parallelism information, storing parallel sizes and groups."""
def __init__(self, ep_size: int, dp_size: int):
_check_sanity()
......@@ -61,9 +60,11 @@ class MoeContext(metaclass=SingletonMeta):
self.world_size = dist.get_world_size()
from colossalai.legacy.core import global_context as gpc
self.max_ep_size = gpc.config.get('max_ep_size', self.world_size)
assert self.world_size % self.max_ep_size == 0, \
"Maximum expert parallel size must be a factor of the number of GPUs"
self.max_ep_size = gpc.config.get("max_ep_size", self.world_size)
assert (
self.world_size % self.max_ep_size == 0
), "Maximum expert parallel size must be a factor of the number of GPUs"
self.min_dp_size = self.world_size // self.max_ep_size
# Enabling kernel optimization may raise error in some cases
......@@ -71,6 +72,7 @@ class MoeContext(metaclass=SingletonMeta):
self.use_kernel_optim = use_kernel_optim
from .random import moe_set_seed
moe_set_seed(seed)
self.has_setup = True
......@@ -88,11 +90,13 @@ class MoeContext(metaclass=SingletonMeta):
number of local experts, the MoeParallelInfo of the current ep_size
"""
gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \
" is not a multiple of ep size or vice versa."
assert gt_flag or lt_flag, (
"Automatic experts placement dose not not support expert number"
" is not a multiple of ep size or vice versa."
)
# If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
# there are multiple experts in each GPU and each GPU has different experts
......
......@@ -16,6 +16,7 @@ class SingletonMeta(type):
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
else:
assert len(args) == 0 and len(
kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.'
assert (
len(args) == 0 and len(kwargs) == 0
), f"{cls.__name__} is a singleton class and a instance has been created."
return cls._instances[cls]
from .alpha_beta_profiler import AlphaBetaProfiler
from .calc_pipeline_strategy import alpa_dp
__all__ = ['AlphaBetaProfiler', 'alpa_dp']
__all__ = ["AlphaBetaProfiler", "alpa_dp"]
......@@ -13,7 +13,7 @@ FRAMEWORK_LATENCY = 0
class AlphaBetaProfiler:
'''
"""
Profile alpha and beta value for a given device list.
Usage:
......@@ -27,17 +27,19 @@ class AlphaBetaProfiler:
(1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
(1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
(4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
'''
def __init__(self,
physical_devices: List[int],
alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
ctype: str = 'a',
warmup: int = 5,
repeat: int = 25,
latency_iters: int = 5,
homogeneous_tolerance: float = 0.1):
'''
"""
def __init__(
self,
physical_devices: List[int],
alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
ctype: str = "a",
warmup: int = 5,
repeat: int = 25,
latency_iters: int = 5,
homogeneous_tolerance: float = 0.1,
):
"""
Args:
physical_devices: A list of device id, each element inside it is the global rank of that device.
alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
......@@ -45,7 +47,7 @@ class AlphaBetaProfiler:
warmup: Number of warmup iterations.
repeat: Number of iterations to measure.
latency_iters: Number of iterations to measure latency.
'''
"""
self.physical_devices = physical_devices
self.ctype = ctype
self.world_size = len(physical_devices)
......@@ -123,7 +125,7 @@ class AlphaBetaProfiler:
return (None, None)
def profile_latency(self, process_group, pg_handler):
'''
"""
This function is used to profile the latency of the given process group with a series of bytes.
Args:
......@@ -132,7 +134,7 @@ class AlphaBetaProfiler:
Returns:
latency: None if the latency is not measured, otherwise the median of the latency_list.
'''
"""
latency_list = []
for i in range(self.latency_iters):
nbytes = int(BYTE << i)
......@@ -148,26 +150,26 @@ class AlphaBetaProfiler:
return latency
def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
'''
"""
This function is used to profile the bandwidth of the given process group.
Args:
process_group: A tuple of global rank of the process group.
pg_handler: The handler of the process group.
'''
"""
(_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
return bandwidth
def profile_ab(self):
'''
"""
This method is used to profiling the alpha and beta value for a given device list.
Returns:
alpha_beta_dict: A dict which maps process group to its alpha and beta value.
'''
"""
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
rank = dist.get_rank()
global_pg_handler = dist.new_group(self.physical_devices)
dist.new_group(self.physical_devices)
def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
assert rank in process_group
......@@ -208,7 +210,7 @@ class AlphaBetaProfiler:
return alpha_beta_dict
def search_best_logical_mesh(self):
'''
"""
This method is used to search the best logical mesh for the given device list.
The best logical mesh is searched in following steps:
......@@ -232,19 +234,19 @@ class AlphaBetaProfiler:
>>> best_logical_mesh = profiler.search_best_logical_mesh()
>>> print(best_logical_mesh)
[[0, 1], [2, 3]]
'''
"""
def _power_of_two(integer):
return integer & (integer - 1) == 0
def _detect_homogeneous_device(alpha_beta_dict):
'''
"""
This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.
Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value
of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
* base_beta.
'''
"""
homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
for process_group, (_, beta) in alpha_beta_dict.items():
if homogeneous_device_dict is None:
......@@ -254,7 +256,8 @@ class AlphaBetaProfiler:
match_beta = None
for beta_value in homogeneous_device_dict.keys():
if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
1 - self.homogeneous_tolerance):
1 - self.homogeneous_tolerance
):
match_beta = beta_value
break
......@@ -267,9 +270,9 @@ class AlphaBetaProfiler:
return homogeneous_device_dict
def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
'''
"""
This function is used to check whether the homogeneous_group contains all physical devices.
'''
"""
flatten_mesh = []
for process_group in homogeneous_group:
flatten_mesh.extend(process_group)
......@@ -277,9 +280,9 @@ class AlphaBetaProfiler:
return len(non_duplicated_flatten_mesh) == len(self.physical_devices)
def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
'''
"""
This function is used to construct the largest ring in the homogeneous_group for each rank.
'''
"""
# Construct the ring
ring = []
ranks_in_ring = []
......@@ -300,7 +303,9 @@ class AlphaBetaProfiler:
check_rank = check_rank_list.pop()
for process_group in homogeneous_group:
if check_rank in process_group:
rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
rank_to_append = (
process_group[0] if process_group[1] == check_rank else process_group[1]
)
if rank_to_append not in ring_for_rank:
stable_status = False
rank_to_check_list.append(rank_to_append)
......@@ -314,7 +319,7 @@ class AlphaBetaProfiler:
assert _power_of_two(self.world_size)
power_of_two = int(math.log2(self.world_size))
median = power_of_two // 2
balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))
balanced_logical_mesh_shape = (2**median, 2 ** (power_of_two - median))
row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
balanced_logical_mesh = []
for row_index in range(row_size):
......@@ -348,7 +353,7 @@ class AlphaBetaProfiler:
return best_logical_mesh
def extract_alpha_beta_for_device_mesh(self):
'''
"""
Extract the mesh_alpha list and mesh_beta list based on the
best logical mesh, which will be used to initialize the device mesh.
......@@ -360,7 +365,7 @@ class AlphaBetaProfiler:
[2.5917552411556242e-05, 0.00010312341153621673]
>>> print(mesh_beta)
[5.875573704655635e-11, 4.7361584445959614e-12]
'''
"""
best_logical_mesh = self.search_best_logical_mesh()
first_axis = [row[0] for row in best_logical_mesh]
......
......@@ -10,8 +10,10 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
while i <= num_devices_per_host:
i *= 2
p += 1
assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, "
f"while now num_devices_per_host = {num_devices_per_host}")
assert pow(2, p) == num_devices_per_host, (
"Only supports the cases where num_devices_per_host is power of two, "
f"while now num_devices_per_host = {num_devices_per_host}"
)
if mode == "alpa":
for i in range(p + 1):
submesh_choices.append((1, pow(2, i)))
......@@ -24,18 +26,19 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
return submesh_choices
def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
best_configs):
def alpa_dp_impl(
num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, best_configs
):
"""Implementation of Alpa DP for pipeline strategy
Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
Arguments:
num_layers: K
num_devices: N*M
num_microbatches: B
submesh_choices: List[(n_i,m_i)]
compute_cost: t_intra
"""
Arguments:
num_layers: K
num_devices: N*M
num_microbatches: B
submesh_choices: List[(n_i,m_i)]
compute_cost: t_intra
"""
# For f, layer ID start from 0
# f[#pipeline stages, layer id that is currently being considered, number of devices used]
f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
......@@ -54,7 +57,7 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
for i in range(num_layers, k, -1):
stage_cost = compute_cost[k, i, m]
new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost
if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]):
if stage_cost <= max_stage_cost and new_cost < f[s, k, d]:
f[s, k, d] = new_cost
f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
f_argmin[s, k, d] = (i, m, best_configs[k, i, m])
......@@ -75,34 +78,34 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
res = []
while current_s > 0 and current_layer < num_layers and current_devices > 0:
next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
next_start_layer, submesh_choice, autosharding_choice = f_argmin[current_s, current_layer, current_devices]
assert next_start_layer != -1 and current_devices != -1
res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
current_s -= 1
current_layer = next_start_layer
current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
assert (current_s == 0 and current_layer == num_layers and current_devices == 0)
assert current_s == 0 and current_layer == num_layers and current_devices == 0
return total_cost, res
def alpa_dp(num_layers,
num_devices,
num_microbatches,
submesh_choices,
num_autosharding_configs,
compute_cost,
gap=1e-6):
def alpa_dp(
num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, gap=1e-6
):
"""Alpa auto stage dynamic programming.
Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
Arguments:
submesh_choices: List[(int,int)]
num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs)
"""
assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices),
num_autosharding_configs), "Cost shape wrong."
assert np.shape(compute_cost) == (
num_layers,
num_layers,
len(submesh_choices),
num_autosharding_configs,
), "Cost shape wrong."
all_possible_stage_costs = np.sort(np.unique(compute_cost))
best_cost = np.inf
best_solution = None
......@@ -117,8 +120,9 @@ def alpa_dp(num_layers,
break
if max_stage_cost - last_max_stage_cost < gap:
continue
cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost,
max_stage_cost, best_configs)
cost, solution = alpa_dp_impl(
num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, max_stage_cost, best_configs
)
if cost < best_cost:
best_cost = cost
best_solution = solution
......
......@@ -40,14 +40,16 @@ class DeviceMesh:
_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
def __init__(self,
physical_mesh_id: torch.Tensor,
mesh_shape: torch.Size = None,
logical_mesh_id: torch.Tensor = None,
mesh_alpha: List[float] = None,
mesh_beta: List[float] = None,
init_process_group: bool = False,
device: str = 'cuda'):
def __init__(
self,
physical_mesh_id: torch.Tensor,
mesh_shape: torch.Size = None,
logical_mesh_id: torch.Tensor = None,
mesh_alpha: List[float] = None,
mesh_beta: List[float] = None,
init_process_group: bool = False,
device: str = "cuda",
):
# ============================
# Physical & Logical Mesh IDs
# ============================
......@@ -57,9 +59,10 @@ class DeviceMesh:
# logical mesh ids can be obtained via two ways
# 1. provide physical mesh id and provide mesh shape
# 2. directly supply the logical mesh id
assert mesh_shape is None or logical_mesh_id is None, \
"Only one of mesh_shape and logical_mesh_id can be specified." \
assert mesh_shape is None or logical_mesh_id is None, (
"Only one of mesh_shape and logical_mesh_id can be specified."
"Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"
)
if logical_mesh_id is None:
self._mesh_shape = mesh_shape
......@@ -71,12 +74,15 @@ class DeviceMesh:
# ensure two things:
# 1. logical and physical mesh IDs should contain the same elements
# 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
"physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
"Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
"Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
assert torch.equal(
torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)
), "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
assert (
torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel()
), "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
assert (
torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel()
), "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
# ===============================================
# coefficient for alpha-beta communication model
......@@ -92,8 +98,9 @@ class DeviceMesh:
self.mesh_beta = tuple(mesh_beta)
# ensure the alpha and beta have the same shape
assert len(self.mesh_alpha) == len(self.mesh_beta), \
"mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
assert len(self.mesh_alpha) == len(
self.mesh_beta
), "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
# =========================
# Device for Process Group
......@@ -109,8 +116,9 @@ class DeviceMesh:
# <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
# }
self._global_to_local_rank_mapping = dict()
self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
tensor=self.logical_mesh_id)
self._init_global_to_logical_rank_mapping(
mapping=self._global_to_local_rank_mapping, tensor=self.logical_mesh_id
)
# create process group
self._process_group_dict = {}
......@@ -194,8 +202,9 @@ class DeviceMesh:
device_list = [_get_device_by_backend(pg) for pg in process_group]
# make sure all devices are the same
assert all([device == device_list[0] for device in device_list]), \
"All devices should be the same, please check your input process groups are created with the same distributed backend."
assert all(
[device == device_list[0] for device in device_list]
), "All devices should be the same, please check your input process groups are created with the same distributed backend."
# create a fake physical mesh id
# as we only get the process group associated with the current process,
......@@ -270,7 +279,7 @@ class DeviceMesh:
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k != '_process_group_dict':
if k != "_process_group_dict":
setattr(result, k, __import__("copy").deepcopy(v, memo))
else:
# process group cannot be copied
......@@ -278,10 +287,9 @@ class DeviceMesh:
setattr(result, k, v)
return result
def _init_global_to_logical_rank_mapping(self,
mapping: Dict,
tensor: torch.Tensor,
index_list: List[int] = []) -> Dict[int, List[int]]:
def _init_global_to_logical_rank_mapping(
self, mapping: Dict, tensor: torch.Tensor, index_list: List[int] = []
) -> Dict[int, List[int]]:
"""
Build a global rank to local rank mapping for each process group in different axis in the logical device mesh.
......@@ -311,15 +319,19 @@ class DeviceMesh:
self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])
def init_logical_process_group(self):
'''
"""
This method is used to initialize the logical process groups which will be used in communications
among logical device mesh.
Note: if init_process_group set to False, you have to call this method manually. Otherwise,
the communication related function, such as ShapeConsistencyManager.apply will raise errors.
'''
"""
# sanity check
assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group"
assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice"
assert (
dist.is_initialized
), "The torch.distributed should be initialized before calling init_logical_process_group"
assert (
not self._is_initialized
), "The logical process group has been initialized, do not call init_logical_process_group twice"
# update the global rank of the current process
self._global_rank_of_current_process = dist.get_rank()
......@@ -389,7 +401,7 @@ class DeviceMesh:
return local_ranks
def _collate_global_ranks_in_same_process_group(self, global_rank):
'''
"""
Give a global rank and return all global ranks involved in its associated process group in each axis.
Example:
......@@ -414,7 +426,7 @@ class DeviceMesh:
0: [0, 4, 8, 12],
1: [0, 1, 2, 3]
# }
'''
"""
# We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping
# for self._global_to_local_rank_mapping
# the key is the global rank
......@@ -437,7 +449,6 @@ class DeviceMesh:
# in the same process group in the given axis
# the _local_rank refers to the local rank of the current process
for _local_rank in range(self.logical_mesh_id.shape[dim]):
# if this dimension is not initialized yet,
# initialize it with an empty array
if dim not in processes_in_the_same_process_group:
......@@ -478,29 +489,37 @@ class DeviceMesh:
flatten_mesh_shape_size = len(self._mesh_shape)
flatten_mesh_shape = [self.num_devices]
return DeviceMesh(self._physical_mesh_id,
tuple(flatten_mesh_shape),
mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
init_process_group=self._init_process_group)
return DeviceMesh(
self._physical_mesh_id,
tuple(flatten_mesh_shape),
mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
init_process_group=self._init_process_group,
)
def all_gather_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
0.1)
return self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.1
def all_reduce_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes +
0.01)
return (
self.mesh_alpha[mesh_dim]
+ self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes
+ 0.01
)
def reduce_scatter_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
0.001)
return (
self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.001
)
def all_to_all_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
penalty_factor = num_devices / 2.0
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
(num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
return (
self.mesh_alpha[mesh_dim]
+ self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor
+ 0.001
)
......@@ -2,16 +2,14 @@ from typing import Callable
import torch
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 12:
META_COMPATIBILITY = False
elif TORCH_MAJOR == 1 and TORCH_MINOR == 12:
from . import _meta_regist_12
META_COMPATIBILITY = True
elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
from . import _meta_regist_13
META_COMPATIBILITY = True
elif TORCH_MAJOR == 2:
META_COMPATIBILITY = True
......@@ -36,7 +34,7 @@ def compatibility(is_backward_compatible: bool = False) -> Callable:
else:
def wrapper(*args, **kwargs):
raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')
raise RuntimeError(f"Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}")
return wrapper
......
......@@ -3,7 +3,7 @@
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
from typing import Callable, List, Optional, Tuple, Union
from typing import List, Optional, Union
import torch
from torch.utils._pytree import tree_map
......@@ -16,13 +16,11 @@ meta_table = {}
def register_meta(op, register_dispatcher=True):
def wrapper(f):
def add_func(op):
meta_table[op] = f
if register_dispatcher:
name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
try:
meta_lib.impl(name, f)
except:
......@@ -48,7 +46,6 @@ def meta_conv(
output_padding: List[int],
groups: int,
):
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
Formula to apply to calculate the length of some dimension of the output
......@@ -125,7 +122,8 @@ def meta_conv(
kernel_size[i],
stride[i],
output_padding_list[i],
))
)
)
else:
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
return ret_shape
......@@ -159,22 +157,42 @@ def meta_conv(
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format()
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out
@register_meta(aten._convolution.default)
def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
*extra_args):
def meta_conv_1(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: List[int],
padding: List[int],
dilation: List[int],
is_transposed: bool,
output_padding: List[int],
groups: int,
*extra_args,
):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@register_meta(aten.convolution_backward.default)
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
padding, dilation, transposed, output_padding, groups, output_mask):
return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
def meta_conv_backward(
grad_output: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
bias_sizes,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_mask,
):
return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device="meta")
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
......@@ -208,7 +226,6 @@ def meta_cuda_rnn(
batch_sizes,
dropout_state,
):
is_input_packed = len(batch_sizes) != 0
if is_input_packed:
seq_length = len(batch_sizes)
......@@ -224,8 +241,11 @@ def meta_cuda_rnn(
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
out_shape = ([mini_batch, seq_length, out_size *
num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
out_shape = (
[mini_batch, seq_length, out_size * num_directions]
if batch_first
else [seq_length, mini_batch, out_size * num_directions]
)
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
......@@ -242,18 +262,20 @@ def meta_cuda_rnn(
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default)
def meta_cudnn_rnn_backward(input: torch.Tensor,
weight: torch.Tensor,
weight_stride0: int,
hx: torch.Tensor,
cx: Optional[torch.Tensor] = None,
*args,
**kwargs):
def meta_cudnn_rnn_backward(
input: torch.Tensor,
weight: torch.Tensor,
weight_stride0: int,
hx: torch.Tensor,
cx: Optional[torch.Tensor] = None,
*args,
**kwargs,
):
print(input, weight, hx, cx)
grad_input = torch.empty_like(input)
grad_weight = torch.empty_like(weight)
grad_hx = torch.empty_like(hx)
grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta')
grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device="meta")
return grad_input, grad_weight, grad_hx, grad_cx
......@@ -298,15 +320,25 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
n_input = input.size(1)
output = torch.empty_like(input)
running_mean = torch.empty((n_input), device='meta')
running_var = torch.empty((n_input), device='meta')
running_mean = torch.empty((n_input), device="meta")
running_var = torch.empty((n_input), device="meta")
return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
save_invstd, train, eps, output_mask):
def meta_bn_backward(
dY: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
running_mean,
running_var,
save_mean,
save_invstd,
train,
eps,
output_mask,
):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
......@@ -319,9 +351,9 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
n_input = input.size(1)
output = torch.empty_like(input)
running_mean = torch.empty((n_input), device='meta')
running_var = torch.empty((n_input), device='meta')
reserve = torch.empty((0), dtype=torch.uint8, device='meta')
running_mean = torch.empty((n_input), device="meta")
running_var = torch.empty((n_input), device="meta")
reserve = torch.empty((0), dtype=torch.uint8, device="meta")
return output, running_mean, running_var, reserve
......@@ -330,8 +362,17 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
save_mean, save_invstd, eps, reserve):
def meta_cudnn_bn_backward(
dY: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
running_mean,
running_var,
save_mean,
save_invstd,
eps,
reserve,
):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
......@@ -345,15 +386,16 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
n_input = input.size(1)
output = torch.empty_like(input)
running_mean = torch.empty((bs, n_input, 1), device='meta')
running_var = torch.empty((bs, n_input, 1), device='meta')
running_mean = torch.empty((bs, n_input, 1), device="meta")
running_var = torch.empty((bs, n_input, 1), device="meta")
return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
grad_input_mask):
def meta_ln_backward(
dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(bias)
......@@ -397,16 +439,19 @@ def meta_index_Tensor(self, indices):
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
assert index.dtype in [torch.long, torch.int8, torch.bool],\
"tensors used as indices must be long, byte or bool tensors"
assert index.dtype in [
torch.long,
torch.int8,
torch.bool,
], "tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
assert index.shape[j] == self.shape[
k +
j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
assert (
index.shape[j] == self.shape[k + j]
), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
......@@ -482,12 +527,15 @@ def meta_index_Tensor(self, indices):
# ============================== Embedding =========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
scale_grad_by_freq):
return torch.empty((num_weights, grad_output.size(-1)),
dtype=grad_output.dtype,
device=grad_output.device,
layout=grad_output.layout)
def meta_embedding_dense_backward(
grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
):
return torch.empty(
(num_weights, grad_output.size(-1)),
dtype=grad_output.dtype,
device=grad_output.device,
layout=grad_output.layout,
)
# ============================== Dropout ===========================================
......
from typing import Any, Callable, Dict, Iterable, List, Tuple
from typing import Any, Dict, Iterable, List, Tuple
import torch
......@@ -18,6 +18,7 @@ try:
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
CODEGEN_AVAILABLE = True
except:
from torch.fx.graph import (
......@@ -32,12 +33,13 @@ except:
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
CODEGEN_AVAILABLE = False
if CODEGEN_AVAILABLE:
__all__ = ['ActivationCheckpointCodeGen']
__all__ = ["ActivationCheckpointCodeGen"]
else:
__all__ = ['python_code_with_activation_checkpoint']
__all__ = ["python_code_with_activation_checkpoint"]
def _gen_saved_tensors_hooks():
......@@ -125,15 +127,14 @@ def _find_ckpt_regions(nodes: List[Node]):
Find the checkpoint regions given a list of consecutive nodes. The outputs will be list
of tuples, each tuple is in the form of (start_index, end_index).
"""
ckpt_nodes = []
ckpt_regions = []
start = -1
end = -1
current_region = None
for idx, node in enumerate(nodes):
if 'activation_checkpoint' in node.meta:
act_ckpt_label = node.meta['activation_checkpoint']
if "activation_checkpoint" in node.meta:
act_ckpt_label = node.meta["activation_checkpoint"]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
......@@ -150,7 +151,7 @@ def _find_ckpt_regions(nodes: List[Node]):
current_region = act_ckpt_label
start = idx
end = -1
elif current_region is not None and not 'activation_checkpoint' in node.meta:
elif current_region is not None and not "activation_checkpoint" in node.meta:
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
......@@ -178,8 +179,8 @@ def _find_offload_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable):
act_offload_label = node.meta['activation_offload']
if "activation_offload" in node.meta and isinstance(node.meta["activation_offload"], Iterable):
act_offload_label = node.meta["activation_offload"]
if current_region == None:
current_region = act_offload_label
......@@ -226,9 +227,9 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
"""
Generate the checkpoint function call code text
"""
outputs = ', '.join(output_vars)
inputs = ', '.join(input_vars)
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
outputs = ", ".join(output_vars)
inputs = ", ".join(input_vars)
return f"{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})"
def _end_of_ckpt(node: Node, check_idx: int) -> bool:
......@@ -240,9 +241,9 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool:
Returns:
bool
"""
if 'activation_checkpoint' in node.meta:
if isinstance(node.meta['activation_checkpoint'], list):
return node.meta['activation_checkpoint'][check_idx] == None
if "activation_checkpoint" in node.meta:
if isinstance(node.meta["activation_checkpoint"], list):
return node.meta["activation_checkpoint"][check_idx] == None
else:
return False
else:
......@@ -260,11 +261,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
current_region = None
for idx, node in enumerate(nodes):
if 'activation_checkpoint' in node.meta:
if isinstance(node.meta['activation_checkpoint'], int):
act_ckpt_label = node.meta['activation_checkpoint']
if "activation_checkpoint" in node.meta:
if isinstance(node.meta["activation_checkpoint"], int):
act_ckpt_label = node.meta["activation_checkpoint"]
else:
act_ckpt_label = node.meta['activation_checkpoint'][check_idx]
act_ckpt_label = node.meta["activation_checkpoint"][check_idx]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
......@@ -298,13 +299,9 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
return ckpt_regions
def emit_ckpt_func(body,
ckpt_func,
node_list: List[Node],
emit_node_func,
delete_unused_value_func,
level=0,
in_ckpt=False):
def emit_ckpt_func(
body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, level=0, in_ckpt=False
):
"""Emit ckpt function in nested way
Args:
body: forward code, in recursive calls, this part will be checkpoint
......@@ -321,17 +318,17 @@ def emit_ckpt_func(body,
inputs, outputs = _find_input_and_output_nodes(node_list)
# if the current checkpoint function use int as label, using old generation method
if isinstance(node_list[0].meta['activation_checkpoint'], int):
label = node_list[0].meta['activation_checkpoint']
if isinstance(node_list[0].meta["activation_checkpoint"], int):
label = node_list[0].meta["activation_checkpoint"]
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
ckpt_func.append(f"{ckpt_fn_def}\n")
for node in node_list:
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = node_list[0].meta.get('activation_offload', False)
ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
activation_offload = node_list[0].meta.get("activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
usage += "\n"
body.append(usage)
......@@ -340,12 +337,12 @@ def emit_ckpt_func(body,
else:
# label given by each layer, e.g. if you are currently at level [0, 1, 1]
# the label will be '0_1_1'
label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]])
label = "_".join([str(idx) for idx in node_list[0].meta["activation_checkpoint"][: level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
ckpt_func.append(f"{ckpt_fn_def}\n")
# if there is more level to fetch
if level + 1 < len(node_list[0].meta['activation_checkpoint']):
if level + 1 < len(node_list[0].meta["activation_checkpoint"]):
ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
......@@ -358,38 +355,45 @@ def emit_ckpt_func(body,
break
if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func,
delete_unused_value_func, level + 1, True)
ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(
ckpt_func,
ckpt_func_buffer,
ckpt_node_list,
emit_node_func,
delete_unused_value_func,
level + 1,
True,
)
node_idx += len(ckpt_node_list)
else:
node = node_list[node_idx]
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
node_idx += 1
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
ckpt_func += ckpt_func_buffer
activation_offload = node_list[0].meta.get('activation_offload', False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
activation_offload = node_list[0].meta.get("activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n"
if in_ckpt:
usage = ' ' + usage
usage = " " + usage
body.append(usage)
# last level
else:
for node in node_list:
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = node_list[0].meta.get('activation_offload', False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
activation_offload = node_list[0].meta.get("activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n"
if in_ckpt:
usage = ' ' + usage
usage = " " + usage
body.append(usage)
......@@ -420,7 +424,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions):
offload_node_list = node_list[start:end + 1]
offload_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs)
offload_outputs.append(outputs)
......@@ -436,7 +440,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# process ckpt_regions
if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list)
......@@ -470,7 +474,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
if within_offload_region:
emit_node_func(node, body)
body[-1] = ' ' + body[-1]
body[-1] = " " + body[-1]
delete_unused_value_func(node, body)
else:
......@@ -508,14 +512,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# find the input and output var names for each region
for idx, (start, end) in enumerate(ckpt_regions):
ckpt_node_list = node_list[start:end + 1]
ckpt_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(ckpt_node_list)
input_vars.append(inputs)
output_vars.append(outputs)
# find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions):
offload_node_list = node_list[start:end + 1]
offload_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs)
offload_outputs.append(outputs)
......@@ -527,7 +531,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if idx in start_idx:
label = start_idx.index(idx)
ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
ckpt_func.append(f'{ckpt_fn_def}\n')
ckpt_func.append(f"{ckpt_fn_def}\n")
within_ckpt_region = True
if idx in offload_starts:
......@@ -559,12 +563,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# NOTE: currently we separate body and ckpt_func definition
if within_ckpt_region:
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
elif within_offload_region:
emit_node_func(node, body)
body[-1] = ' ' + body[-1]
body[-1] = " " + body[-1]
delete_unused_value_func(node, body)
else:
......@@ -576,13 +580,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# generate return statement
label = end_idx.index(idx)
return_statement = _gen_ckpt_output(output_vars[label])
return_statement = f' {return_statement}\n\n'
return_statement = f" {return_statement}\n\n"
ckpt_func.append(return_statement)
# we need to check if the checkpoint need to offload the input
start_node_idx = start_idx[label]
if 'activation_offload' in node_list[start_node_idx].meta:
activation_offload = node_list[start_node_idx].meta['activation_offload']
if "activation_offload" in node_list[start_node_idx].meta:
activation_offload = node_list[start_node_idx].meta["activation_offload"]
else:
activation_offload = False
......@@ -594,8 +598,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if input_node.op != "placeholder":
non_leaf_input = 1
for user in input_node.users:
if 'activation_checkpoint' in user.meta:
if user.meta['activation_checkpoint'] == label:
if "activation_checkpoint" in user.meta:
if user.meta["activation_checkpoint"] == label:
if user.op == "call_module":
if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
......@@ -610,7 +614,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# generate checkpoint function call in a new line
usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
usage += '\n'
usage += "\n"
body.append(usage)
within_ckpt_region = False
......@@ -621,7 +625,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if CODEGEN_AVAILABLE:
class ActivationCheckpointCodeGen(CodeGen):
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
......@@ -629,7 +632,7 @@ if CODEGEN_AVAILABLE:
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
maybe_return_annotation: List[str] = ['']
maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
......@@ -637,7 +640,7 @@ if CODEGEN_AVAILABLE:
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
......@@ -662,16 +665,16 @@ if CODEGEN_AVAILABLE:
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
return '()'
return "()"
typename = _type_repr(o)
if hasattr(o, '__origin__'):
if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
if hasattr(o, '__args__'):
if hasattr(o, "__args__"):
# Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__]
......@@ -690,19 +693,18 @@ if CODEGEN_AVAILABLE:
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, '_fields'):
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
args_s = ', '.join(_get_repr(a) for a in args)
kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
args_s = ", ".join(_get_repr(a) for a in args)
kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
if args_s and kwargs_s:
return f'{args_s}, {kwargs_s}'
return f"{args_s}, {kwargs_s}"
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use
......@@ -728,90 +730,101 @@ if CODEGEN_AVAILABLE:
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
if user.op == 'placeholder':
if user.op == "placeholder":
return
if user.op == 'output':
body.append('\n')
if user.op == "output":
body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
body.append(f'; {to_delete_str}\n')
to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
body.append(f"; {to_delete_str}\n")
else:
body.append('\n')
body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
if node.op == 'placeholder':
maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
if node.op == "placeholder":
assert isinstance(node.target, str)
maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
raw_name = node.target.replace('*', '')
maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace("*", "")
if raw_name != repr(node):
body.append(f'{repr(node)} = {raw_name}\n')
body.append(f"{repr(node)} = {raw_name}\n")
return
elif node.op == 'call_method':
elif node.op == "call_method":
assert isinstance(node.target, str)
body.append(
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
f'({_format_args(node.args[1:], node.kwargs)})')
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
f"({_format_args(node.args[1:], node.kwargs)})"
)
return
elif node.op == 'call_function':
elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
body.append(f'{repr(node)}{maybe_type_annotation} = '
f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
body.append(
f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
)
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
body.append(
f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
)
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if global_name == 'getattr' and \
isinstance(node.args, tuple) and \
isinstance(node.args[1], str) and \
node.args[1].isidentifier() and \
len(node.args) == 2:
if (
global_name == "getattr"
and isinstance(node.args, tuple)
and isinstance(node.args[1], str)
and node.args[1].isidentifier()
and len(node.args) == 2
):
body.append(
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
)
return
body.append(
f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
if node.meta.get('is_wrapped', False):
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
)
if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
elif node.op == 'call_module':
elif node.op == "call_module":
assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = '
f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
body.append(
f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
)
return
elif node.op == 'get_attr':
elif node.op == "get_attr":
assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
elif node.op == 'output':
elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0]))
return
raise NotImplementedError(f'node: {node.op} {node.target}')
raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes):
if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
......@@ -820,13 +833,13 @@ if CODEGEN_AVAILABLE:
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
body.append('pass\n')
body.append("pass\n")
if len(wrapped_fns) > 0:
wrap_name = add_global('wrap', torch.fx.wrap)
wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
wrap_name = add_global("wrap", torch.fx.wrap)
wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
wrap_stmts = ''
wrap_stmts = ""
if self._body_transformer:
body = self._body_transformer(body)
......@@ -837,11 +850,11 @@ if CODEGEN_AVAILABLE:
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = ''.join(ckpt_func) + prologue
prologue = "".join(ckpt_func) + prologue
prologue = prologue
code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n'))
code = "".join(body)
code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{prologue}
......@@ -861,7 +874,7 @@ else:
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
maybe_return_annotation: List[str] = ['']
maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
......@@ -869,7 +882,7 @@ else:
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
......@@ -894,12 +907,12 @@ else:
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
return '()'
return "()"
typename = _type_repr(o)
# This is a generic type, e.g. typing.List[torch.Tensor]
if hasattr(o, '__origin__'):
if hasattr(o, "__origin__"):
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
......@@ -934,84 +947,94 @@ else:
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
if user.op == 'placeholder':
if user.op == "placeholder":
return
if user.op == 'output':
body.append('\n')
if user.op == "output":
body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
body.append(f'; {to_delete_str}\n')
to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
body.append(f"; {to_delete_str}\n")
else:
body.append('\n')
body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
if node.op == 'placeholder':
maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
if node.op == "placeholder":
assert isinstance(node.target, str)
maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
raw_name = node.target.replace('*', '')
maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace("*", "")
if raw_name != repr(node):
body.append(f'{repr(node)} = {raw_name}\n')
body.append(f"{repr(node)} = {raw_name}\n")
return
elif node.op == 'call_method':
elif node.op == "call_method":
assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
f'({_format_args(node.args[1:], node.kwargs)})')
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
f"({_format_args(node.args[1:], node.kwargs)})"
)
return
elif node.op == 'call_function':
elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
body.append(f'{repr(node)}{maybe_type_annotation} = '
f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
body.append(
f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
)
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if global_name == 'getattr' and \
isinstance(node.args, tuple) and \
isinstance(node.args[1], str) and \
node.args[1].isidentifier() and \
len(node.args) == 2:
if (
global_name == "getattr"
and isinstance(node.args, tuple)
and isinstance(node.args[1], str)
and node.args[1].isidentifier()
and len(node.args) == 2
):
body.append(
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
)
return
body.append(
f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
if node.meta.get('is_wrapped', False):
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
)
if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
elif node.op == 'call_module':
elif node.op == "call_module":
assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = '
f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
body.append(
f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
)
return
elif node.op == 'get_attr':
elif node.op == "get_attr":
assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
elif node.op == 'output':
elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
if self._pytree_info is None:
body.append(f'return {repr(node.args[0])}')
body.append(f"return {repr(node.args[0])}")
else:
body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)')
body.append(f"return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)")
return
raise NotImplementedError(f'node: {node.op} {node.target}')
raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes):
if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in self.nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
......@@ -1020,33 +1043,34 @@ else:
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
body.append('pass\n')
body.append("pass\n")
if self._pytree_info is not None:
orig_args = self._pytree_info.orig_args
has_orig_self = (orig_args[0] == 'self')
has_orig_self = orig_args[0] == "self"
if has_orig_self:
free_vars.insert(0, 'self')
if len(free_vars) > 0: # pytree has placeholders in it
free_vars.insert(0, "self")
if len(free_vars) > 0: # pytree has placeholders in it
body.insert(
0,
f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n")
f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n",
)
else:
orig_args = free_vars
if len(wrapped_fns) > 0:
wrap_name = add_global('wrap', torch.fx.wrap)
wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
wrap_name = add_global("wrap", torch.fx.wrap)
wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
wrap_stmts = ''
wrap_stmts = ""
ckpt_func = ''.join(ckpt_func)
ckpt_func = "".join(ckpt_func)
# If the original function didn't have self as its first argument, we
# would have added it.
if len(orig_args) == 0 or orig_args[0] != 'self':
orig_args.insert(0, 'self')
code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n'))
if len(orig_args) == 0 or orig_args[0] != "self":
orig_args.insert(0, "self")
code = "".join(body)
code = "\n".join(" " + line for line in code.split("\n"))
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
......
import os
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Type, Union
from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
from torch.nn.modules.module import _addindent
try:
from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
from torch.fx.graph import Graph, PythonCode, _PyTreeCodeGen
from torch.fx.graph_module import GraphModule, _exec_with_source, _forward_from_src, _WrappedCall
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
COLOGM = True
except:
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
COLOGM = False
if COLOGM:
class ColoGraphModule(GraphModule):
def __init__(self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: Graph,
class_name: str = 'GraphModule',
ckpt_codegen: bool = True):
def __init__(
self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: Graph,
class_name: str = "GraphModule",
ckpt_codegen: bool = True,
):
if ckpt_codegen:
graph.set_codegen(ActivationCheckpointCodeGen())
super().__init__(root, graph, class_name)
......@@ -60,7 +63,7 @@ if COLOGM:
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
python_code = self._graph.python_code(root_module='self')
python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
# To split ckpt functions code and forward code
......@@ -83,8 +86,8 @@ if COLOGM:
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None
if '_wrapped_call' not in vars(cls):
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
if "_wrapped_call" not in vars(cls):
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs):
return self._wrapped_call(self, *args, **kwargs)
......@@ -108,7 +111,7 @@ if COLOGM:
"""
folder = Path(folder)
Path(folder).mkdir(exist_ok=True)
torch.save(self.state_dict(), folder / 'state_dict.pt')
torch.save(self.state_dict(), folder / "state_dict.pt")
tab = " " * 4
# we add import colossalai here
......@@ -125,7 +128,13 @@ class {module_name}(torch.nn.Module):
def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
safe_reprs = [
nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
nn.Linear,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
]
if type(module) in safe_reprs:
return f"{module.__repr__()}"
......@@ -136,10 +145,10 @@ class {module_name}(torch.nn.Module):
for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module)
if module_str is None:
module_file = folder / f'{module_name}.pt'
module_file = folder / f"{module_name}.pt"
torch.save(module, module_file)
blobified_modules.append(module_name)
module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
......@@ -156,19 +165,20 @@ class {module_name}(torch.nn.Module):
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n"
module_file = folder / 'module.py'
module_file = folder / "module.py"
module_file.write_text(model_str)
init_file = folder / '__init__.py'
init_file.write_text('from .module import *')
init_file = folder / "__init__.py"
init_file.write_text("from .module import *")
if len(blobified_modules) > 0:
warnings.warn("Was not able to save the following children modules as reprs -"
f"saved as pickled files instead: {blobified_modules}")
warnings.warn(
"Was not able to save the following children modules as reprs -"
f"saved as pickled files instead: {blobified_modules}"
)
else:
class ColoGraphModule(GraphModule):
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = "GraphModule"):
super().__init__(root, graph, class_name)
import numpy as np
import torch
import tqdm
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
......@@ -29,15 +27,15 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
accumulate_bwd_flop = 0
block_nodes = []
for node in gm.graph.nodes:
if 'block_split' in node.name:
if "block_split" in node.name:
continue
accumulate_fwd_flop += node.fwd_flop
accumulate_bwd_flop += node.bwd_flop
if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
with gm.graph.inserting_after(node):
block_node = gm.graph.create_node('call_function', block_split)
setattr(block_node, 'fwd_flop', accumulate_fwd_flop)
setattr(block_node, 'bwd_flop', accumulate_bwd_flop)
block_node = gm.graph.create_node("call_function", block_split)
setattr(block_node, "fwd_flop", accumulate_fwd_flop)
setattr(block_node, "bwd_flop", accumulate_bwd_flop)
accumulate_fwd_flop = 0
accumulate_bwd_flop = 0
block_nodes.append(block_node)
......@@ -47,7 +45,7 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
def remove_blocks(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
if (node.op, node.target) == ('call_function', block_split):
if (node.op, node.target) == ("call_function", block_split):
gm.graph.erase_node(node)
......@@ -55,8 +53,8 @@ def get_compute_costs(node_list):
num_nodes = len(node_list)
all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)
for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0):
for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False):
for start in tqdm.tqdm(range(num_nodes), desc="start pos", position=0):
for end in tqdm.tqdm(range(start, num_nodes), desc="end pos", position=1, leave=False):
selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
all_compute_cost[start, end] = sum(selected_flops)
......@@ -78,12 +76,14 @@ def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_cost
# record start node index for next stage in this partition
f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
f[0, num_nodes] = 0
for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks
for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False):
for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False):
for s in tqdm.tqdm(
range(1, num_stages + 1), desc="stage", position=2, leave=False
): # pylint: disable=too-many-nested-blocks
for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc="start node", position=3, leave=False):
for k in tqdm.tqdm(range(num_nodes, i, -1), desc="mid node", position=4, leave=False):
stage_cost = compute_costs[i, k - 1]
new_cost = f[s - 1, k] + stage_cost
if (stage_cost <= max_compute_cost and new_cost < f[s, i]):
if stage_cost <= max_compute_cost and new_cost < f[s, i]:
f[s, i] = new_cost
f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
f_argmin[s, i] = k
......@@ -113,7 +113,7 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
best_cost = np.inf
best_solution = None
last_max_compute_cost = 0.0
gap = 1e6 # temporary magic number, unit: flops
gap = 1e6 # temporary magic number, unit: flops
for max_compute_cost in tqdm.tqdm(max_compute_costs):
# Pruning to reduce search space.
......@@ -122,8 +122,9 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
if max_compute_cost - last_max_compute_cost < gap:
continue
cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs,
max_compute_cost)
cost, solution = do_dp_split_gpipe_impl(
len(node_list), num_stages, num_microbatches, compute_costs, max_compute_cost
)
if cost < best_cost:
best_cost = cost
......@@ -137,15 +138,15 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
# split_mode:
# 'node': fx_node
# 'block': many fx_nodes construct a block
def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01):
assert mode in ['node', 'block']
def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode="block", block_limit=0.01):
assert mode in ["node", "block"]
# nodes or blocks will be used in partition.
node_list = []
if mode == 'node':
if mode == "node":
for node in gm.graph.nodes:
node_list.append(node)
elif mode == 'block':
elif mode == "block":
node_list = construct_blocks(gm, limit=block_limit)
else:
pass
......@@ -154,16 +155,16 @@ def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches
best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)
for (_, next_start_node) in best_solution:
for _, next_start_node in best_solution:
if pp_size <= 1:
break
node = node_list[next_start_node]
with gm.graph.inserting_before(node):
split_node = gm.graph.create_node('call_function', pipe_split)
split_node = gm.graph.create_node("call_function", pipe_split)
pp_size -= 1
# remove block node if possible
if mode == 'block':
if mode == "block":
remove_blocks(gm)
gm.recompile()
......@@ -178,7 +179,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
# To use avgcompute_split_pass, we need run meta_info_prop interpreter first.
# If nodes don't have meta info, this pass will fall back to normal balanced split pass.
check_node = list(mod_graph.nodes)[0]
if 'tensor_meta' not in check_node.meta:
if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size)
total_fwd_flop = 0
......@@ -190,7 +191,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes:
if pp_size <= 1:
break
if 'pipe_split' in node.name:
if "pipe_split" in node.name:
continue
accumulate_fwd_flop += node.fwd_flop
if accumulate_fwd_flop >= partition_flop:
......@@ -199,7 +200,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
partition_flop = total_fwd_flop // pp_size
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
......@@ -218,12 +219,12 @@ def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
if accumulate_num_node >= avg_num_node:
accumulate_num_node = 0
pp_size -= 1
if node.next.op == 'output':
if node.next.op == "output":
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
else:
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
......@@ -250,18 +251,18 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
# If the next node is output node, we will insert split annotation before
# node to make sure there is at least one node in last partition.
if node.next.op == 'output':
if node.next.op == "output":
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
else:
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
if pp_size > 1:
node_counter = 0
for node in mod_graph.nodes:
if pp_size <= 1:
break
if node.op == 'placeholder':
if node.op == "placeholder":
continue
elif node_counter == 0:
node_counter += 1
......@@ -269,7 +270,7 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
node_counter = 0
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
......@@ -283,7 +284,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
# To use balanced_split_pass_v2, we need run meta_info_prop interpreter first.
# If nodes don't have meta info, this pass will fall back to normal balanced split pass.
check_node = list(mod_graph.nodes)[0]
if 'tensor_meta' not in check_node.meta:
if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size)
total_element_size = 0
......@@ -295,7 +296,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes:
if pp_size <= 1:
break
if 'pipe_split' in node.name:
if "pipe_split" in node.name:
continue
accumulate_node_size += node.node_size
if accumulate_node_size >= partition_size:
......@@ -304,7 +305,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
partition_size = total_element_size // pp_size
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
......@@ -333,7 +334,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
accumulate_layer_amount = 0
pp_size -= 1
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
......@@ -346,7 +347,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
def split_callback(n: torch.fx.Node):
nonlocal part_idx
if (n.op, n.target) == ('call_function', pipe_split):
if (n.op, n.target) == ("call_function", pipe_split):
part_idx += 1
return part_idx
......@@ -355,7 +356,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes:
if (node.op, node.target) == ('call_function', pipe_split):
if (node.op, node.target) == ("call_function", pipe_split):
submodule.graph.erase_node(node)
submodule.recompile()
split_submodules.append(submodule)
......
from dataclasses import asdict
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.fx
......@@ -85,10 +85,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
self._is_proped = True
result, meta_info = super().run_node(n)
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
n.meta['type'] = type(result)
setattr(n, "node_size", n.meta.get("fwd_mem_tmp", 0) + n.meta.get("fwd_mem_out", 0))
n.meta["type"] = type(result)
# retain the autograd graph
for param in self.module.parameters():
......@@ -98,7 +98,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
# Main Node running APIs
@compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
......@@ -119,7 +119,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
......@@ -138,7 +138,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
......@@ -157,7 +157,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_function(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
......@@ -175,7 +175,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_method(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
......@@ -197,7 +197,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_module(submod, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
......@@ -228,7 +228,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
"""
return self.run(*args)
def summary(self, unit: str = 'MB') -> str:
def summary(self, unit: str = "MB") -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
......@@ -238,9 +238,11 @@ class ConcreteInfoProp(torch.fx.Interpreter):
try:
from tabulate import tabulate
except ImportError:
print("`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")
print(
"`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
......@@ -249,10 +251,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
def mem_repr(mem: int) -> str:
unit_divisor_map = {
'kb': 1024,
'mb': 1024**2,
'gb': 1024**3,
'tb': 1024**4,
"kb": 1024,
"mb": 1024**2,
"gb": 1024**3,
"tb": 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
......@@ -261,30 +263,32 @@ class ConcreteInfoProp(torch.fx.Interpreter):
for node in self.module.graph.nodes:
node: Node
node_summaries.append([
node.op,
str(node),
time_repr(node.meta['fwd_time']),
time_repr(node.meta['bwd_time']),
node.meta['save_fwd_in'],
mem_repr(node.meta['fwd_mem_out']),
mem_repr(node.meta['fwd_mem_tmp']),
mem_repr(node.meta['bwd_mem_out']),
mem_repr(node.meta['bwd_mem_tmp']),
])
node_summaries.append(
[
node.op,
str(node),
time_repr(node.meta["fwd_time"]),
time_repr(node.meta["bwd_time"]),
node.meta["save_fwd_in"],
mem_repr(node.meta["fwd_mem_out"]),
mem_repr(node.meta["fwd_mem_tmp"]),
mem_repr(node.meta["bwd_mem_out"]),
mem_repr(node.meta["bwd_mem_tmp"]),
]
)
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
'Op type',
'Op',
'Forward time',
'Backward time',
'SAVE_FWD_IN',
'FWD_OUT',
'FWD_TMP',
'BWD_OUT',
'BWD_TMP',
"Op type",
"Op",
"Forward time",
"Backward time",
"SAVE_FWD_IN",
"FWD_OUT",
"FWD_TMP",
"BWD_OUT",
"BWD_TMP",
]
return tabulate(node_summaries, headers=headers, stralign='right')
return tabulate(node_summaries, headers=headers, stralign="right")
import torch
from typing import List
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
import builtins
import operator
from copy import deepcopy
from typing import List
import torch
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
def apply(*args, **kwargs):
......@@ -24,16 +21,16 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi
origin_node_sharding_spec_dict = {}
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector
setattr(node, 'best_strategy', strategies_vector[strategy_index])
setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec)
setattr(node, "best_strategy", strategies_vector[strategy_index])
setattr(node, "sharding_spec", strategies_vector[strategy_index].output_sharding_spec)
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec
# apply the sharding spec of parameters
for node in nodes:
if node.op == 'call_module':
if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target)
origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {})
setattr(target_module.weight, 'sharding_spec', origin_sharding_spec)
setattr(target_module.weight, "sharding_spec", origin_sharding_spec)
target_weight_sharding_spec = node.best_strategy.input_shardings[1]
target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))
apply(target_module.weight, target_weight_sharding_spec)
......@@ -51,10 +48,10 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi
# add above dicts into graph
for node in nodes:
if node.op != 'placeholder':
if node.op != "placeholder":
with mod_graph.inserting_before(node):
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
break
return sharding_spec_convert_dict, origin_node_sharding_spec_dict
......@@ -70,13 +67,13 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
node_to_index_dict = {}
index = 0
for node in nodes:
if node.target == 'sharding_spec_convert_dict':
if node.target == "sharding_spec_convert_dict":
input_dict_node = node
continue
if node.target == 'origin_node_sharding_spec_dict':
if node.target == "origin_node_sharding_spec_dict":
origin_dict_node = node
continue
if not hasattr(node, 'best_strategy'):
if not hasattr(node, "best_strategy"):
continue
node_to_index_dict[node] = index
index += 1
......@@ -84,28 +81,28 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
# add shape consistency apply function into graph
for node in nodes:
if not hasattr(node, 'best_strategy'):
if not hasattr(node, "best_strategy"):
continue
with mod_graph.inserting_after(node):
origin_spec_node = mod_graph.create_node('call_function',
operator.getitem,
args=(origin_dict_node, node_to_index_dict[node]))
origin_spec_node = mod_graph.create_node(
"call_function", operator.getitem, args=(origin_dict_node, node_to_index_dict[node])
)
with mod_graph.inserting_after(origin_spec_node):
set_sharding_spec_node = mod_graph.create_node('call_function',
builtins.setattr,
args=(node, 'sharding_spec', origin_spec_node))
set_sharding_spec_node = mod_graph.create_node(
"call_function", builtins.setattr, args=(node, "sharding_spec", origin_spec_node)
)
for user_node in node.strategies_vector.successor_nodes:
node_index = user_node.strategies_vector.predecessor_nodes.index(node)
with mod_graph.inserting_before(user_node):
input_specs_node = mod_graph.create_node('call_function',
operator.getitem,
args=(input_dict_node, node_to_index_dict[node]))
input_specs_node = mod_graph.create_node(
"call_function", operator.getitem, args=(input_dict_node, node_to_index_dict[node])
)
with mod_graph.inserting_before(user_node):
sharding_spec_node = mod_graph.create_node('call_function',
operator.getitem,
args=(input_specs_node, node_index))
sharding_spec_node = mod_graph.create_node(
"call_function", operator.getitem, args=(input_specs_node, node_index)
)
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))
shape_consistency_node = mod_graph.create_node("call_function", apply, args=(node, sharding_spec_node))
return gm
......@@ -109,13 +109,13 @@ class MetaInfoProp(torch.fx.Interpreter):
return TensorMetadata(None, None, False, None, 0, False)
tensor_meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
n.meta["tensor_meta"] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0))
n.meta['type'] = type(result)
setattr(n, "node_size", activation_size(n.meta.get("fwd_out", 0)) + activation_size(n.meta.get("fwd_tmp", 0)))
setattr(n, "fwd_flop", n.meta.get("fwd_flop", 0))
setattr(n, "bwd_flop", n.meta.get("bwd_flop", 0))
n.meta["type"] = type(result)
# retain the autograd graph
for param in self.module.parameters():
......@@ -125,7 +125,7 @@ class MetaInfoProp(torch.fx.Interpreter):
# Main Node running APIs
@compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
......@@ -146,7 +146,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
......@@ -165,7 +165,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
......@@ -184,7 +184,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return profile_function(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
......@@ -202,7 +202,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return profile_method(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
......@@ -224,7 +224,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return profile_module(submod)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
......@@ -240,7 +240,7 @@ class MetaInfoProp(torch.fx.Interpreter):
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
if hasattr(args[0], '_tensor'):
if hasattr(args[0], "_tensor"):
return args[0], GraphInfo(fwd_in=[args[0]._tensor])
return args[0], GraphInfo(save_fwd_in=True)
......@@ -257,7 +257,7 @@ class MetaInfoProp(torch.fx.Interpreter):
"""
return super().run(*args)
def summary(self, unit: str = 'MB') -> str:
def summary(self, unit: str = "MB") -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
......@@ -267,9 +267,11 @@ class MetaInfoProp(torch.fx.Interpreter):
try:
from tabulate import tabulate
except ImportError:
print("`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")
print(
"`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
......@@ -278,10 +280,10 @@ class MetaInfoProp(torch.fx.Interpreter):
def mem_repr(mem: int) -> str:
unit_divisor_map = {
'kb': 1024,
'mb': 1024**2,
'gb': 1024**3,
'tb': 1024**4,
"kb": 1024,
"mb": 1024**2,
"gb": 1024**3,
"tb": 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
......@@ -292,35 +294,37 @@ class MetaInfoProp(torch.fx.Interpreter):
for node in self.module.graph.nodes:
node: Node
accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node)
node_summaries.append([
node.op,
str(node),
flops_repr(node.meta['fwd_flop']),
flops_repr(node.meta['bwd_flop']),
mem_repr(accumulate_size),
mem_repr(calculate_fwd_in(node)),
mem_repr(calculate_fwd_out(node)),
mem_repr(calculate_fwd_tmp(node)),
mem_repr(node.meta['bwd_mem_out']),
mem_repr(node.meta['bwd_mem_tmp']),
])
node_summaries.append(
[
node.op,
str(node),
flops_repr(node.meta["fwd_flop"]),
flops_repr(node.meta["bwd_flop"]),
mem_repr(accumulate_size),
mem_repr(calculate_fwd_in(node)),
mem_repr(calculate_fwd_out(node)),
mem_repr(calculate_fwd_tmp(node)),
mem_repr(node.meta["bwd_mem_out"]),
mem_repr(node.meta["bwd_mem_tmp"]),
]
)
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
'Op type',
'Op',
'Forward FLOPs',
'Backward FLOPs',
'Accumulated Memory',
'FWD_IN',
'FWD_OUT',
'FWD_TMP',
'BWD_OUT',
'BWD_TMP',
"Op type",
"Op",
"Forward FLOPs",
"Backward FLOPs",
"Accumulated Memory",
"FWD_IN",
"FWD_OUT",
"FWD_TMP",
"BWD_OUT",
"BWD_TMP",
]
return tabulate(node_summaries, headers=headers, stralign='right')
return tabulate(node_summaries, headers=headers, stralign="right")
def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None:
......@@ -344,15 +348,16 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit:
Returns:
torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
"""
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
interp = MetaInfoProp(gm.to(device))
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
interp.propagate(*args, **kwargs)
if verbose:
interp.summary(unit)
gm.to('cpu')
gm.to("cpu")
del interp
return gm
......@@ -5,7 +5,6 @@ import torch
from packaging import version
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split
from colossalai.fx.passes.meta_info_prop import TensorMetadata
......@@ -13,9 +12,9 @@ from colossalai.fx.passes.split_module import Partition
def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
'''
"""
This pass is only used to do the gpt2 performance test, it may move into adding_split_node_pass.py, and will be deprecated in future.
'''
"""
mod_graph = gm.graph
valid_children_size = 0
valid_children = []
......@@ -39,40 +38,40 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, parti
part_index += 1
pp_size -= 1
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
'''
"""
This pass will be used in gpt2 test, only a part of changes may be added into
split_with_split_nodes_pass, and it will be deprecated in future.
'''
"""
part_idx = 0
def eliminate_unused_placeholders(gm):
for node in gm.graph.nodes:
if node.op == 'placeholder':
if node.op == "placeholder":
if not len(node.users):
gm.graph.erase_node(node)
gm.recompile()
return gm
def refill_outputs_and_placeholders(gm, next_partition_placeholders):
'''
"""
This method is used to eliminate the outputs in previous partition which is unused in next partition.
In split module pass, it treats partitions as a DAG, but we need treat them as a single direction linked list in pipeline parallel.
The difference is if a output from partition 0 is an input argument of partition 3, the DAG will not transfer it
to partition 1 and partition 2. However, in single direction linked list, we need to do so.
'''
"""
output_type = None
output_args = []
non_output_list = []
new_placeholder_list = []
for node in gm.graph.nodes:
if node.op == 'output':
if node.op == "output":
if isinstance(node.args[0], (tuple, list)):
output_type = node.args[0].__class__
output_args.extend([n.name for n in node.args[0]])
......@@ -114,7 +113,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
continue
for node in gm.graph.nodes:
if node.op == 'placeholder':
if node.op == "placeholder":
new_placeholder_list.append(node.name)
if output_type is not None:
gm.graph.output(output_type(output_args))
......@@ -125,7 +124,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
def split_callback(n: torch.fx.Node):
nonlocal part_idx
if (n.op, n.target) == ('call_function', pipe_split):
if (n.op, n.target) == ("call_function", pipe_split):
part_idx += 1
return part_idx
......@@ -134,7 +133,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes:
if (node.op, node.target) == ('call_function', pipe_split):
if (node.op, node.target) == ("call_function", pipe_split):
submodule.graph.erase_node(node)
submodule.recompile()
split_submodules.append(submodule)
......@@ -200,13 +199,12 @@ def split_module_for_gpt2_test(
_gen_all_ancestors_set(node)
for n in list(all_ancestors):
if n.op != 'placeholder' and n._fx_partition > partition_name:
if n.op != "placeholder" and n._fx_partition > partition_name:
n._fx_partition = partition_name
def record_cross_partition_use(def_node: torch.fx.node.Node,
use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, '_fx_partition', None)
use_partition_name = getattr(use_node, '_fx_partition', None)
def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
# if 'tensor_meta' in def_node.meta:
# if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
......@@ -237,7 +235,7 @@ def split_module_for_gpt2_test(
if node.op in ["placeholder"]:
continue
if node.op == 'output':
if node.op == "output":
# partition_name = str(split_callback(node))
# def _set_output_args_partition(n, partition_name):
# n._fx_partition = partition_name
......@@ -252,12 +250,12 @@ def split_module_for_gpt2_test(
partitions[partition_name] = partition = Partition(partition_name)
partition.node_names.append(node.name)
origin_partition_name = getattr(node, '_fx_partition', None)
origin_partition_name = getattr(node, "_fx_partition", None)
if origin_partition_name is None:
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions: List[str] = []
......@@ -287,7 +285,7 @@ def split_module_for_gpt2_test(
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
if hasattr(node, '_fx_partition'):
if hasattr(node, "_fx_partition"):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
......@@ -295,26 +293,24 @@ def split_module_for_gpt2_test(
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
if node.op not in ['call_module', 'get_attr']:
if node.op not in ["call_module", "get_attr"]:
target = node.target
else:
target_atoms = node.target.split('.')
target_atoms = node.target.split(".")
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
raise RuntimeError(f'Operator target {node.target} not found!')
raise RuntimeError(f"Operator target {node.target} not found!")
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
target = '_'.join(target_atoms)
target = "_".join(target_atoms)
partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
new_node = partition.graph.create_node(op=node.op,
target=target,
args=gathered_args,
kwargs=gathered_kwargs,
name=node.name)
new_node = partition.graph.create_node(
op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs, name=node.name
)
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
......@@ -323,14 +319,14 @@ def split_module_for_gpt2_test(
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
if version.parse(torch.__version__) < version.parse('1.11.0'):
if node.op == "placeholder":
if version.parse(torch.__version__) < version.parse("1.11.0"):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name] = base_mod_graph.placeholder(
node.name, type_expr=node.type, default_value=default_value
)
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:
......@@ -344,13 +340,14 @@ def split_module_for_gpt2_test(
# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)
# Construct GraphModule for this partition
submod_name = f'submod_{partition_name}'
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
partition.graph) # noqa: B950
submod_name = f"submod_{partition_name}"
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
partition.targets, partition.graph
) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
......@@ -358,14 +355,14 @@ def split_module_for_gpt2_test(
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes:
if node.op == 'output':
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
if node.op == "output":
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
......@@ -9,8 +9,19 @@ from colossalai.legacy.tensor.distspec import ShardSpec
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
torch.add,
operator.add,
torch.abs,
torch.cos,
torch.exp,
torch.mul,
operator.mul,
operator.floordiv,
operator.truediv,
operator.neg,
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
]
......@@ -72,7 +83,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
# traverse the graph to look for consecutive linear layers
is_linear_module = False
if node.op == 'call_module':
if node.op == "call_module":
# look for the linear layer
module = node.graph.owning_module.get_submodule(node.target)
if isinstance(module, nn.Linear):
......@@ -82,31 +93,31 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
# it means the first linear has been found and the current module
# is the second linear
# set the current linear module to be row-sharded
annotation_record['row'] = module
annotation_record["row"] = module
for shard_type, module in annotation_record.items():
# add row sharding spec
if shard_type == 'row':
if shard_type == "row":
dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
comp_spec = ComputeSpec(ComputePattern.TP1D)
setattr(module.weight, 'pg', process_group)
setattr(module.weight, 'dist_spec', dist_spec)
setattr(module.weight, 'comp_spec', comp_spec)
elif shard_type == 'col':
setattr(module.weight, "pg", process_group)
setattr(module.weight, "dist_spec", dist_spec)
setattr(module.weight, "comp_spec", comp_spec)
elif shard_type == "col":
weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
weight_comp_spec.output_replicate = False
setattr(module.weight, 'pg', process_group)
setattr(module.weight, 'dist_spec', weight_dist_spec)
setattr(module.weight, 'comp_spec', weight_comp_spec)
setattr(module.weight, "pg", process_group)
setattr(module.weight, "dist_spec", weight_dist_spec)
setattr(module.weight, "comp_spec", weight_comp_spec)
if module.bias is not None:
bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
bias_comp_spec.output_replicate = False
setattr(module.bias, 'pg', process_group)
setattr(module.bias, 'dist_spec', bias_dist_spec)
setattr(module.bias, 'comp_spec', bias_comp_spec)
setattr(module.bias, "pg", process_group)
setattr(module.bias, "dist_spec", bias_dist_spec)
setattr(module.bias, "comp_spec", bias_comp_spec)
start_tracking = False
annotation_record.clear()
else:
......@@ -114,16 +125,16 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
# it means the current layer is the first linear
# set the linear layer to be col-sharded
start_tracking = True
annotation_record['col'] = module
annotation_record["col"] = module
if start_tracking and not is_linear_module:
# check against the white list
# if non-element wise op is found, we reset the tracking
if node.op == 'call_module':
if node.op == "call_module":
module = node.graph.owning_module.get_submodule(node.target)
if module.__class__ not in ELEMENTWISE_MODULE_OP:
start_tracking = False
elif node.op == 'call_function' or node.op == 'call_method':
elif node.op == "call_function" or node.op == "call_method":
if node.target not in ELEMENTWISE_FUNC_OP:
start_tracking = False
elif len(node.users.keys()) > 1:
......
......@@ -25,12 +25,14 @@ class Partition:
self.targets: Dict[str, Any] = {}
def __repr__(self) -> str:
return f"name: {self.name},\n" \
f" nodes: {self.node_names},\n" \
f" inputs: {self.inputs},\n" \
f" outputs: {self.outputs},\n" \
f" partitions dependent on: {self.partitions_dependent_on},\n" \
return (
f"name: {self.name},\n"
f" nodes: {self.node_names},\n"
f" inputs: {self.inputs},\n"
f" outputs: {self.outputs},\n"
f" partitions dependent on: {self.partitions_dependent_on},\n"
f" partition dependents: {self.partition_dependents}"
)
# Creates subgraphs out of main graph
......@@ -117,10 +119,9 @@ def split_module(
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, torch.fx.node.Node] = {}
def record_cross_partition_use(def_node: torch.fx.node.Node,
use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, '_fx_partition', None)
use_partition_name = getattr(use_node, '_fx_partition', None)
def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
if def_partition_name is not None:
def_partition = partitions[def_partition_name]
......@@ -134,7 +135,7 @@ def split_module(
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
......@@ -161,7 +162,7 @@ def split_module(
if node.op in ["placeholder"]:
continue
if node.op == 'output':
if node.op == "output":
if merge_output:
torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
else:
......@@ -178,7 +179,7 @@ def split_module(
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions: List[str] = []
......@@ -208,7 +209,7 @@ def split_module(
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
if hasattr(node, '_fx_partition'):
if hasattr(node, "_fx_partition"):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
......@@ -216,25 +217,24 @@ def split_module(
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
if node.op not in ['call_module', 'get_attr']:
if node.op not in ["call_module", "get_attr"]:
target = node.target
else:
target_atoms = node.target.split('.')
target_atoms = node.target.split(".")
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
raise RuntimeError(f'Operator target {node.target} not found!')
raise RuntimeError(f"Operator target {node.target} not found!")
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
target = '_'.join(target_atoms)
target = "_".join(target_atoms)
partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
new_node = partition.graph.create_node(op=node.op,
target=target,
args=gathered_args,
kwargs=gathered_kwargs)
new_node = partition.graph.create_node(
op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs
)
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
......@@ -243,14 +243,14 @@ def split_module(
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
if version.parse(torch.__version__) < version.parse('1.11.0'):
if node.op == "placeholder":
if version.parse(torch.__version__) < version.parse("1.11.0"):
base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name] = base_mod_graph.placeholder(
node.target, type_expr=node.type, default_value=default_value
)
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:
......@@ -264,13 +264,14 @@ def split_module(
# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)
# Construct GraphModule for this partition
submod_name = f'submod_{partition_name}'
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
partition.graph) # noqa: B950
submod_name = f"submod_{partition_name}"
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
partition.targets, partition.graph
) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
......@@ -278,15 +279,15 @@ def split_module(
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes:
if node.op == 'output':
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
if node.op == "output":
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
for partition_name in sorted_partitions:
partition = partitions[partition_name]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment