Commit 079bf3cb (unverified signature), authored by Hongxin Liu, committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
...@@ -94,7 +94,7 @@ class ProcessGroupMesh: ...@@ -94,7 +94,7 @@ class ProcessGroupMesh:
return np.unravel_index(rank, shape) return np.unravel_index(rank, shape)
@staticmethod @staticmethod
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int: def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = "raise") -> int:
"""Convert a coordinate to a rank. """Convert a coordinate to a rank.
mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html. mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
with wrap, index out of range would be wrapped around. with wrap, index out of range would be wrapped around.
...@@ -141,8 +141,9 @@ class ProcessGroupMesh: ...@@ -141,8 +141,9 @@ class ProcessGroupMesh:
return list(self._group_to_ranks[group]) return list(self._group_to_ranks[group])
@staticmethod @staticmethod
def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int, def get_coords_along_axis(
indices_at_axis: List[int]) -> List[Tuple[int, ...]]: base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
) -> List[Tuple[int, ...]]:
"""Get coordinates along the given axis. """Get coordinates along the given axis.
Args: Args:
...@@ -155,13 +156,12 @@ class ProcessGroupMesh: ...@@ -155,13 +156,12 @@ class ProcessGroupMesh:
""" """
coords_in_group = [] coords_in_group = []
for idx in indices_at_axis: for idx in indices_at_axis:
coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1:]) coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
return coords_in_group return coords_in_group
def create_group_along_axis(self, def create_group_along_axis(
axis: int, self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
indices_at_axis: Optional[List[int]] = None, ) -> ProcessGroup:
backend: Optional[str] = None) -> ProcessGroup:
"""Create all process groups along the given axis, and return the one which the current process belongs to. """Create all process groups along the given axis, and return the one which the current process belongs to.
Args: Args:
...@@ -186,10 +186,9 @@ class ProcessGroupMesh: ...@@ -186,10 +186,9 @@ class ProcessGroupMesh:
target_group = group target_group = group
return target_group return target_group
def get_group_along_axis(self, def get_group_along_axis(
axis: int, self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
indices_at_axis: Optional[List[int]] = None, ) -> ProcessGroup:
backend: Optional[str] = None) -> ProcessGroup:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created. """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
Args: Args:
......
...@@ -3,6 +3,6 @@ from .config import Config, ConfigException ...@@ -3,6 +3,6 @@ from .config import Config, ConfigException
# from .moe_context import MOE_CONTEXT # from .moe_context import MOE_CONTEXT
__all__ = [ __all__ = [
'Config', "Config",
'ConfigException', "ConfigException",
] ]
...@@ -5,6 +5,7 @@ import inspect ...@@ -5,6 +5,7 @@ import inspect
import sys import sys
from importlib.machinery import SourceFileLoader from importlib.machinery import SourceFileLoader
from pathlib import Path from pathlib import Path
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
...@@ -41,7 +42,7 @@ class Config(dict): ...@@ -41,7 +42,7 @@ class Config(dict):
self.__setattr__(key, value) self.__setattr__(key, value)
def update(self, config): def update(self, config):
assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.' assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
for k, v in config.items(): for k, v in config.items():
self._add_item(k, v) self._add_item(k, v)
return self return self
...@@ -66,11 +67,11 @@ class Config(dict): ...@@ -66,11 +67,11 @@ class Config(dict):
elif isinstance(filename, Path): elif isinstance(filename, Path):
filepath = filename.absolute() filepath = filename.absolute()
assert filepath.exists(), f'{filename} is not found, please check your configuration path' assert filepath.exists(), f"{filename} is not found, please check your configuration path"
# check extension # check extension
extension = filepath.suffix extension = filepath.suffix
assert extension == '.py', 'only .py files are supported' assert extension == ".py", "only .py files are supported"
# import the config as module # import the config as module
remove_path = False remove_path = False
...@@ -86,13 +87,13 @@ class Config(dict): ...@@ -86,13 +87,13 @@ class Config(dict):
config = Config() config = Config()
for k, v in module.__dict__.items(): for k, v in module.__dict__.items():
if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v): if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
continue continue
else: else:
config._add_item(k, v) config._add_item(k, v)
logger = get_dist_logger() logger = get_dist_logger()
logger.debug('variables which starts with __, is a module or class declaration are omitted in config file') logger.debug("variables which starts with __, is a module or class declaration are omitted in config file")
# remove module # remove module
del sys.modules[module_name] del sys.modules[module_name]
......
...@@ -9,14 +9,13 @@ from colossalai.legacy.tensor import ProcessGroup ...@@ -9,14 +9,13 @@ from colossalai.legacy.tensor import ProcessGroup
def _check_sanity(): def _check_sanity():
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1: if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
raise NotImplementedError("Moe is not compatible with tensor or " raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.")
"pipeline parallel at present.")
class MoeParallelInfo: class MoeParallelInfo:
"""Moe parallelism information, storing parallel sizes and groups. """Moe parallelism information, storing parallel sizes and groups."""
"""
def __init__(self, ep_size: int, dp_size: int): def __init__(self, ep_size: int, dp_size: int):
_check_sanity() _check_sanity()
...@@ -61,9 +60,11 @@ class MoeContext(metaclass=SingletonMeta): ...@@ -61,9 +60,11 @@ class MoeContext(metaclass=SingletonMeta):
self.world_size = dist.get_world_size() self.world_size = dist.get_world_size()
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
self.max_ep_size = gpc.config.get('max_ep_size', self.world_size)
assert self.world_size % self.max_ep_size == 0, \ self.max_ep_size = gpc.config.get("max_ep_size", self.world_size)
"Maximum expert parallel size must be a factor of the number of GPUs" assert (
self.world_size % self.max_ep_size == 0
), "Maximum expert parallel size must be a factor of the number of GPUs"
self.min_dp_size = self.world_size // self.max_ep_size self.min_dp_size = self.world_size // self.max_ep_size
# Enabling kernel optimization may raise error in some cases # Enabling kernel optimization may raise error in some cases
...@@ -71,6 +72,7 @@ class MoeContext(metaclass=SingletonMeta): ...@@ -71,6 +72,7 @@ class MoeContext(metaclass=SingletonMeta):
self.use_kernel_optim = use_kernel_optim self.use_kernel_optim = use_kernel_optim
from .random import moe_set_seed from .random import moe_set_seed
moe_set_seed(seed) moe_set_seed(seed)
self.has_setup = True self.has_setup = True
...@@ -88,11 +90,13 @@ class MoeContext(metaclass=SingletonMeta): ...@@ -88,11 +90,13 @@ class MoeContext(metaclass=SingletonMeta):
number of local experts, the MoeParallelInfo of the current ep_size number of local experts, the MoeParallelInfo of the current ep_size
""" """
gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \ assert gt_flag or lt_flag, (
" is not a multiple of ep size or vice versa." "Automatic experts placement dose not not support expert number"
" is not a multiple of ep size or vice versa."
)
# If the number of experts is greater than maximum expert parallel size. a.k.a ep_size, # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
# there are multiple experts in each GPU and each GPU has different experts # there are multiple experts in each GPU and each GPU has different experts
......
...@@ -16,6 +16,7 @@ class SingletonMeta(type): ...@@ -16,6 +16,7 @@ class SingletonMeta(type):
instance = super().__call__(*args, **kwargs) instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance cls._instances[cls] = instance
else: else:
assert len(args) == 0 and len( assert (
kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.' len(args) == 0 and len(kwargs) == 0
), f"{cls.__name__} is a singleton class and a instance has been created."
return cls._instances[cls] return cls._instances[cls]
from .alpha_beta_profiler import AlphaBetaProfiler from .alpha_beta_profiler import AlphaBetaProfiler
from .calc_pipeline_strategy import alpa_dp from .calc_pipeline_strategy import alpa_dp
__all__ = ['AlphaBetaProfiler', 'alpa_dp'] __all__ = ["AlphaBetaProfiler", "alpa_dp"]
...@@ -13,7 +13,7 @@ FRAMEWORK_LATENCY = 0 ...@@ -13,7 +13,7 @@ FRAMEWORK_LATENCY = 0
class AlphaBetaProfiler: class AlphaBetaProfiler:
''' """
Profile alpha and beta value for a given device list. Profile alpha and beta value for a given device list.
Usage: Usage:
...@@ -27,17 +27,19 @@ class AlphaBetaProfiler: ...@@ -27,17 +27,19 @@ class AlphaBetaProfiler:
(1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12), (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
(1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11), (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
(4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)} (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
''' """
def __init__(self, def __init__(
physical_devices: List[int], self,
alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None, physical_devices: List[int],
ctype: str = 'a', alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
warmup: int = 5, ctype: str = "a",
repeat: int = 25, warmup: int = 5,
latency_iters: int = 5, repeat: int = 25,
homogeneous_tolerance: float = 0.1): latency_iters: int = 5,
''' homogeneous_tolerance: float = 0.1,
):
"""
Args: Args:
physical_devices: A list of device id, each element inside it is the global rank of that device. physical_devices: A list of device id, each element inside it is the global rank of that device.
alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs. alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
...@@ -45,7 +47,7 @@ class AlphaBetaProfiler: ...@@ -45,7 +47,7 @@ class AlphaBetaProfiler:
warmup: Number of warmup iterations. warmup: Number of warmup iterations.
repeat: Number of iterations to measure. repeat: Number of iterations to measure.
latency_iters: Number of iterations to measure latency. latency_iters: Number of iterations to measure latency.
''' """
self.physical_devices = physical_devices self.physical_devices = physical_devices
self.ctype = ctype self.ctype = ctype
self.world_size = len(physical_devices) self.world_size = len(physical_devices)
...@@ -123,7 +125,7 @@ class AlphaBetaProfiler: ...@@ -123,7 +125,7 @@ class AlphaBetaProfiler:
return (None, None) return (None, None)
def profile_latency(self, process_group, pg_handler): def profile_latency(self, process_group, pg_handler):
''' """
This function is used to profile the latency of the given process group with a series of bytes. This function is used to profile the latency of the given process group with a series of bytes.
Args: Args:
...@@ -132,7 +134,7 @@ class AlphaBetaProfiler: ...@@ -132,7 +134,7 @@ class AlphaBetaProfiler:
Returns: Returns:
latency: None if the latency is not measured, otherwise the median of the latency_list. latency: None if the latency is not measured, otherwise the median of the latency_list.
''' """
latency_list = [] latency_list = []
for i in range(self.latency_iters): for i in range(self.latency_iters):
nbytes = int(BYTE << i) nbytes = int(BYTE << i)
...@@ -148,26 +150,26 @@ class AlphaBetaProfiler: ...@@ -148,26 +150,26 @@ class AlphaBetaProfiler:
return latency return latency
def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)): def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
''' """
This function is used to profile the bandwidth of the given process group. This function is used to profile the bandwidth of the given process group.
Args: Args:
process_group: A tuple of global rank of the process group. process_group: A tuple of global rank of the process group.
pg_handler: The handler of the process group. pg_handler: The handler of the process group.
''' """
(_, bandwidth) = self._profile(process_group, pg_handler, maxbytes) (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
return bandwidth return bandwidth
def profile_ab(self): def profile_ab(self):
''' """
This method is used to profiling the alpha and beta value for a given device list. This method is used to profiling the alpha and beta value for a given device list.
Returns: Returns:
alpha_beta_dict: A dict which maps process group to its alpha and beta value. alpha_beta_dict: A dict which maps process group to its alpha and beta value.
''' """
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {} alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
rank = dist.get_rank() rank = dist.get_rank()
global_pg_handler = dist.new_group(self.physical_devices) dist.new_group(self.physical_devices)
def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup): def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
assert rank in process_group assert rank in process_group
...@@ -208,7 +210,7 @@ class AlphaBetaProfiler: ...@@ -208,7 +210,7 @@ class AlphaBetaProfiler:
return alpha_beta_dict return alpha_beta_dict
def search_best_logical_mesh(self): def search_best_logical_mesh(self):
''' """
This method is used to search the best logical mesh for the given device list. This method is used to search the best logical mesh for the given device list.
The best logical mesh is searched in following steps: The best logical mesh is searched in following steps:
...@@ -232,19 +234,19 @@ class AlphaBetaProfiler: ...@@ -232,19 +234,19 @@ class AlphaBetaProfiler:
>>> best_logical_mesh = profiler.search_best_logical_mesh() >>> best_logical_mesh = profiler.search_best_logical_mesh()
>>> print(best_logical_mesh) >>> print(best_logical_mesh)
[[0, 1], [2, 3]] [[0, 1], [2, 3]]
''' """
def _power_of_two(integer): def _power_of_two(integer):
return integer & (integer - 1) == 0 return integer & (integer - 1) == 0
def _detect_homogeneous_device(alpha_beta_dict): def _detect_homogeneous_device(alpha_beta_dict):
''' """
This function is used to detect whether the devices in the alpha_beta_dict are homogeneous. This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.
Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value
of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)] of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
* base_beta. * base_beta.
''' """
homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {} homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
for process_group, (_, beta) in alpha_beta_dict.items(): for process_group, (_, beta) in alpha_beta_dict.items():
if homogeneous_device_dict is None: if homogeneous_device_dict is None:
...@@ -254,7 +256,8 @@ class AlphaBetaProfiler: ...@@ -254,7 +256,8 @@ class AlphaBetaProfiler:
match_beta = None match_beta = None
for beta_value in homogeneous_device_dict.keys(): for beta_value in homogeneous_device_dict.keys():
if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * ( if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
1 - self.homogeneous_tolerance): 1 - self.homogeneous_tolerance
):
match_beta = beta_value match_beta = beta_value
break break
...@@ -267,9 +270,9 @@ class AlphaBetaProfiler: ...@@ -267,9 +270,9 @@ class AlphaBetaProfiler:
return homogeneous_device_dict return homogeneous_device_dict
def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]): def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
''' """
This function is used to check whether the homogeneous_group contains all physical devices. This function is used to check whether the homogeneous_group contains all physical devices.
''' """
flatten_mesh = [] flatten_mesh = []
for process_group in homogeneous_group: for process_group in homogeneous_group:
flatten_mesh.extend(process_group) flatten_mesh.extend(process_group)
...@@ -277,9 +280,9 @@ class AlphaBetaProfiler: ...@@ -277,9 +280,9 @@ class AlphaBetaProfiler:
return len(non_duplicated_flatten_mesh) == len(self.physical_devices) return len(non_duplicated_flatten_mesh) == len(self.physical_devices)
def _construct_largest_ring(homogeneous_group: List[Tuple[int]]): def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
''' """
This function is used to construct the largest ring in the homogeneous_group for each rank. This function is used to construct the largest ring in the homogeneous_group for each rank.
''' """
# Construct the ring # Construct the ring
ring = [] ring = []
ranks_in_ring = [] ranks_in_ring = []
...@@ -300,7 +303,9 @@ class AlphaBetaProfiler: ...@@ -300,7 +303,9 @@ class AlphaBetaProfiler:
check_rank = check_rank_list.pop() check_rank = check_rank_list.pop()
for process_group in homogeneous_group: for process_group in homogeneous_group:
if check_rank in process_group: if check_rank in process_group:
rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1] rank_to_append = (
process_group[0] if process_group[1] == check_rank else process_group[1]
)
if rank_to_append not in ring_for_rank: if rank_to_append not in ring_for_rank:
stable_status = False stable_status = False
rank_to_check_list.append(rank_to_append) rank_to_check_list.append(rank_to_append)
...@@ -314,7 +319,7 @@ class AlphaBetaProfiler: ...@@ -314,7 +319,7 @@ class AlphaBetaProfiler:
assert _power_of_two(self.world_size) assert _power_of_two(self.world_size)
power_of_two = int(math.log2(self.world_size)) power_of_two = int(math.log2(self.world_size))
median = power_of_two // 2 median = power_of_two // 2
balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median)) balanced_logical_mesh_shape = (2**median, 2 ** (power_of_two - median))
row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1] row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
balanced_logical_mesh = [] balanced_logical_mesh = []
for row_index in range(row_size): for row_index in range(row_size):
...@@ -348,7 +353,7 @@ class AlphaBetaProfiler: ...@@ -348,7 +353,7 @@ class AlphaBetaProfiler:
return best_logical_mesh return best_logical_mesh
def extract_alpha_beta_for_device_mesh(self): def extract_alpha_beta_for_device_mesh(self):
''' """
Extract the mesh_alpha list and mesh_beta list based on the Extract the mesh_alpha list and mesh_beta list based on the
best logical mesh, which will be used to initialize the device mesh. best logical mesh, which will be used to initialize the device mesh.
...@@ -360,7 +365,7 @@ class AlphaBetaProfiler: ...@@ -360,7 +365,7 @@ class AlphaBetaProfiler:
[2.5917552411556242e-05, 0.00010312341153621673] [2.5917552411556242e-05, 0.00010312341153621673]
>>> print(mesh_beta) >>> print(mesh_beta)
[5.875573704655635e-11, 4.7361584445959614e-12] [5.875573704655635e-11, 4.7361584445959614e-12]
''' """
best_logical_mesh = self.search_best_logical_mesh() best_logical_mesh = self.search_best_logical_mesh()
first_axis = [row[0] for row in best_logical_mesh] first_axis = [row[0] for row in best_logical_mesh]
......
...@@ -10,8 +10,10 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"): ...@@ -10,8 +10,10 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
while i <= num_devices_per_host: while i <= num_devices_per_host:
i *= 2 i *= 2
p += 1 p += 1
assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, " assert pow(2, p) == num_devices_per_host, (
f"while now num_devices_per_host = {num_devices_per_host}") "Only supports the cases where num_devices_per_host is power of two, "
f"while now num_devices_per_host = {num_devices_per_host}"
)
if mode == "alpa": if mode == "alpa":
for i in range(p + 1): for i in range(p + 1):
submesh_choices.append((1, pow(2, i))) submesh_choices.append((1, pow(2, i)))
...@@ -24,18 +26,19 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"): ...@@ -24,18 +26,19 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
return submesh_choices return submesh_choices
def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, def alpa_dp_impl(
best_configs): num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, best_configs
):
"""Implementation of Alpa DP for pipeline strategy """Implementation of Alpa DP for pipeline strategy
Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
Arguments: Arguments:
num_layers: K num_layers: K
num_devices: N*M num_devices: N*M
num_microbatches: B num_microbatches: B
submesh_choices: List[(n_i,m_i)] submesh_choices: List[(n_i,m_i)]
compute_cost: t_intra compute_cost: t_intra
""" """
# For f, layer ID start from 0 # For f, layer ID start from 0
# f[#pipeline stages, layer id that is currently being considered, number of devices used] # f[#pipeline stages, layer id that is currently being considered, number of devices used]
f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32) f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
...@@ -54,7 +57,7 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com ...@@ -54,7 +57,7 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
for i in range(num_layers, k, -1): for i in range(num_layers, k, -1):
stage_cost = compute_cost[k, i, m] stage_cost = compute_cost[k, i, m]
new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost
if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]): if stage_cost <= max_stage_cost and new_cost < f[s, k, d]:
f[s, k, d] = new_cost f[s, k, d] = new_cost
f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices]) f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
f_argmin[s, k, d] = (i, m, best_configs[k, i, m]) f_argmin[s, k, d] = (i, m, best_configs[k, i, m])
...@@ -75,34 +78,34 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com ...@@ -75,34 +78,34 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
res = [] res = []
while current_s > 0 and current_layer < num_layers and current_devices > 0: while current_s > 0 and current_layer < num_layers and current_devices > 0:
next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices]) next_start_layer, submesh_choice, autosharding_choice = f_argmin[current_s, current_layer, current_devices]
assert next_start_layer != -1 and current_devices != -1 assert next_start_layer != -1 and current_devices != -1
res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice)) res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
current_s -= 1 current_s -= 1
current_layer = next_start_layer current_layer = next_start_layer
current_devices -= np.prod(np.array(submesh_choices[submesh_choice])) current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
assert (current_s == 0 and current_layer == num_layers and current_devices == 0) assert current_s == 0 and current_layer == num_layers and current_devices == 0
return total_cost, res return total_cost, res
def alpa_dp(num_layers, def alpa_dp(
num_devices, num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, gap=1e-6
num_microbatches, ):
submesh_choices,
num_autosharding_configs,
compute_cost,
gap=1e-6):
"""Alpa auto stage dynamic programming. """Alpa auto stage dynamic programming.
Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
Arguments: Arguments:
submesh_choices: List[(int,int)] submesh_choices: List[(int,int)]
num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh) num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs) compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs)
""" """
assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices), assert np.shape(compute_cost) == (
num_autosharding_configs), "Cost shape wrong." num_layers,
num_layers,
len(submesh_choices),
num_autosharding_configs,
), "Cost shape wrong."
all_possible_stage_costs = np.sort(np.unique(compute_cost)) all_possible_stage_costs = np.sort(np.unique(compute_cost))
best_cost = np.inf best_cost = np.inf
best_solution = None best_solution = None
...@@ -117,8 +120,9 @@ def alpa_dp(num_layers, ...@@ -117,8 +120,9 @@ def alpa_dp(num_layers,
break break
if max_stage_cost - last_max_stage_cost < gap: if max_stage_cost - last_max_stage_cost < gap:
continue continue
cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, cost, solution = alpa_dp_impl(
max_stage_cost, best_configs) num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, max_stage_cost, best_configs
)
if cost < best_cost: if cost < best_cost:
best_cost = cost best_cost = cost
best_solution = solution best_solution = solution
......
...@@ -40,14 +40,16 @@ class DeviceMesh: ...@@ -40,14 +40,16 @@ class DeviceMesh:
_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"} _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
def __init__(self, def __init__(
physical_mesh_id: torch.Tensor, self,
mesh_shape: torch.Size = None, physical_mesh_id: torch.Tensor,
logical_mesh_id: torch.Tensor = None, mesh_shape: torch.Size = None,
mesh_alpha: List[float] = None, logical_mesh_id: torch.Tensor = None,
mesh_beta: List[float] = None, mesh_alpha: List[float] = None,
init_process_group: bool = False, mesh_beta: List[float] = None,
device: str = 'cuda'): init_process_group: bool = False,
device: str = "cuda",
):
# ============================ # ============================
# Physical & Logical Mesh IDs # Physical & Logical Mesh IDs
# ============================ # ============================
...@@ -57,9 +59,10 @@ class DeviceMesh: ...@@ -57,9 +59,10 @@ class DeviceMesh:
# logical mesh ids can be obtained via two ways # logical mesh ids can be obtained via two ways
# 1. provide physical mesh id and provide mesh shape # 1. provide physical mesh id and provide mesh shape
# 2. directly supply the logical mesh id # 2. directly supply the logical mesh id
assert mesh_shape is None or logical_mesh_id is None, \ assert mesh_shape is None or logical_mesh_id is None, (
"Only one of mesh_shape and logical_mesh_id can be specified." \ "Only one of mesh_shape and logical_mesh_id can be specified."
"Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id" "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"
)
if logical_mesh_id is None: if logical_mesh_id is None:
self._mesh_shape = mesh_shape self._mesh_shape = mesh_shape
...@@ -71,12 +74,15 @@ class DeviceMesh: ...@@ -71,12 +74,15 @@ class DeviceMesh:
# ensure two things: # ensure two things:
# 1. logical and physical mesh IDs should contain the same elements # 1. logical and physical mesh IDs should contain the same elements
# 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \ assert torch.equal(
"physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id." torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)
assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \ ), "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
"Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again." assert (
assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \ torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel()
"Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again." ), "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
assert (
torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel()
), "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
# =============================================== # ===============================================
# coefficient for alpha-beta communication model # coefficient for alpha-beta communication model
...@@ -92,8 +98,9 @@ class DeviceMesh: ...@@ -92,8 +98,9 @@ class DeviceMesh:
self.mesh_beta = tuple(mesh_beta) self.mesh_beta = tuple(mesh_beta)
# ensure the alpha and beta have the same shape # ensure the alpha and beta have the same shape
assert len(self.mesh_alpha) == len(self.mesh_beta), \ assert len(self.mesh_alpha) == len(
"mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again." self.mesh_beta
), "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
# ========================= # =========================
# Device for Process Group # Device for Process Group
...@@ -109,8 +116,9 @@ class DeviceMesh: ...@@ -109,8 +116,9 @@ class DeviceMesh:
# <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...] # <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
# } # }
self._global_to_local_rank_mapping = dict() self._global_to_local_rank_mapping = dict()
self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping, self._init_global_to_logical_rank_mapping(
tensor=self.logical_mesh_id) mapping=self._global_to_local_rank_mapping, tensor=self.logical_mesh_id
)
# create process group # create process group
self._process_group_dict = {} self._process_group_dict = {}
...@@ -194,8 +202,9 @@ class DeviceMesh: ...@@ -194,8 +202,9 @@ class DeviceMesh:
device_list = [_get_device_by_backend(pg) for pg in process_group] device_list = [_get_device_by_backend(pg) for pg in process_group]
# make sure all devices are the same # make sure all devices are the same
assert all([device == device_list[0] for device in device_list]), \ assert all(
"All devices should be the same, please check your input process groups are created with the same distributed backend." [device == device_list[0] for device in device_list]
), "All devices should be the same, please check your input process groups are created with the same distributed backend."
# create a fake physical mesh id # create a fake physical mesh id
# as we only get the process group associated with the current process, # as we only get the process group associated with the current process,
...@@ -270,7 +279,7 @@ class DeviceMesh: ...@@ -270,7 +279,7 @@ class DeviceMesh:
result = cls.__new__(cls) result = cls.__new__(cls)
memo[id(self)] = result memo[id(self)] = result
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k != '_process_group_dict': if k != "_process_group_dict":
setattr(result, k, __import__("copy").deepcopy(v, memo)) setattr(result, k, __import__("copy").deepcopy(v, memo))
else: else:
# process group cannot be copied # process group cannot be copied
...@@ -278,10 +287,9 @@ class DeviceMesh: ...@@ -278,10 +287,9 @@ class DeviceMesh:
setattr(result, k, v) setattr(result, k, v)
return result return result
def _init_global_to_logical_rank_mapping(self, def _init_global_to_logical_rank_mapping(
mapping: Dict, self, mapping: Dict, tensor: torch.Tensor, index_list: List[int] = []
tensor: torch.Tensor, ) -> Dict[int, List[int]]:
index_list: List[int] = []) -> Dict[int, List[int]]:
""" """
Build a global rank to local rank mapping for each process group in different axis in the logical device mesh. Build a global rank to local rank mapping for each process group in different axis in the logical device mesh.
...@@ -311,15 +319,19 @@ class DeviceMesh: ...@@ -311,15 +319,19 @@ class DeviceMesh:
self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index]) self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])
def init_logical_process_group(self): def init_logical_process_group(self):
''' """
This method is used to initialize the logical process groups which will be used in communications This method is used to initialize the logical process groups which will be used in communications
among logical device mesh. among logical device mesh.
Note: if init_process_group set to False, you have to call this method manually. Otherwise, Note: if init_process_group set to False, you have to call this method manually. Otherwise,
the communication related function, such as ShapeConsistencyManager.apply will raise errors. the communication related function, such as ShapeConsistencyManager.apply will raise errors.
''' """
# sanity check # sanity check
assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group" assert (
assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice" dist.is_initialized
), "The torch.distributed should be initialized before calling init_logical_process_group"
assert (
not self._is_initialized
), "The logical process group has been initialized, do not call init_logical_process_group twice"
# update the global rank of the current process # update the global rank of the current process
self._global_rank_of_current_process = dist.get_rank() self._global_rank_of_current_process = dist.get_rank()
...@@ -389,7 +401,7 @@ class DeviceMesh: ...@@ -389,7 +401,7 @@ class DeviceMesh:
return local_ranks return local_ranks
def _collate_global_ranks_in_same_process_group(self, global_rank): def _collate_global_ranks_in_same_process_group(self, global_rank):
''' """
Give a global rank and return all global ranks involved in its associated process group in each axis. Give a global rank and return all global ranks involved in its associated process group in each axis.
Example: Example:
...@@ -414,7 +426,7 @@ class DeviceMesh: ...@@ -414,7 +426,7 @@ class DeviceMesh:
0: [0, 4, 8, 12], 0: [0, 4, 8, 12],
1: [0, 1, 2, 3] 1: [0, 1, 2, 3]
# } # }
''' """
# We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping # We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping
# for self._global_to_local_rank_mapping # for self._global_to_local_rank_mapping
# the key is the global rank # the key is the global rank
...@@ -437,7 +449,6 @@ class DeviceMesh: ...@@ -437,7 +449,6 @@ class DeviceMesh:
# in the same process group in the given axis # in the same process group in the given axis
# the _local_rank refers to the local rank of the current process # the _local_rank refers to the local rank of the current process
for _local_rank in range(self.logical_mesh_id.shape[dim]): for _local_rank in range(self.logical_mesh_id.shape[dim]):
# if this dimension is not initialized yet, # if this dimension is not initialized yet,
# initialize it with an empty array # initialize it with an empty array
if dim not in processes_in_the_same_process_group: if dim not in processes_in_the_same_process_group:
...@@ -478,29 +489,37 @@ class DeviceMesh: ...@@ -478,29 +489,37 @@ class DeviceMesh:
flatten_mesh_shape_size = len(self._mesh_shape) flatten_mesh_shape_size = len(self._mesh_shape)
flatten_mesh_shape = [self.num_devices] flatten_mesh_shape = [self.num_devices]
return DeviceMesh(self._physical_mesh_id, return DeviceMesh(
tuple(flatten_mesh_shape), self._physical_mesh_id,
mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), tuple(flatten_mesh_shape),
mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
init_process_group=self._init_process_group) mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
init_process_group=self._init_process_group,
)
def all_gather_cost(self, num_bytes, mesh_dim): def all_gather_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim] num_devices = self.logical_mesh_id.shape[mesh_dim]
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + return self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.1
0.1)
def all_reduce_cost(self, num_bytes, mesh_dim): def all_reduce_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim] num_devices = self.logical_mesh_id.shape[mesh_dim]
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes + return (
0.01) self.mesh_alpha[mesh_dim]
+ self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes
+ 0.01
)
def reduce_scatter_cost(self, num_bytes, mesh_dim): def reduce_scatter_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim] num_devices = self.logical_mesh_id.shape[mesh_dim]
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + return (
0.001) self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.001
)
def all_to_all_cost(self, num_bytes, mesh_dim): def all_to_all_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim] num_devices = self.logical_mesh_id.shape[mesh_dim]
penalty_factor = num_devices / 2.0 penalty_factor = num_devices / 2.0
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * return (
(num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) self.mesh_alpha[mesh_dim]
+ self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor
+ 0.001
)
...@@ -2,16 +2,14 @@ from typing import Callable ...@@ -2,16 +2,14 @@ from typing import Callable
import torch import torch
TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split('.')[1]) TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 12: if TORCH_MAJOR == 1 and TORCH_MINOR < 12:
META_COMPATIBILITY = False META_COMPATIBILITY = False
elif TORCH_MAJOR == 1 and TORCH_MINOR == 12: elif TORCH_MAJOR == 1 and TORCH_MINOR == 12:
from . import _meta_regist_12
META_COMPATIBILITY = True META_COMPATIBILITY = True
elif TORCH_MAJOR == 1 and TORCH_MINOR == 13: elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
from . import _meta_regist_13
META_COMPATIBILITY = True META_COMPATIBILITY = True
elif TORCH_MAJOR == 2: elif TORCH_MAJOR == 2:
META_COMPATIBILITY = True META_COMPATIBILITY = True
...@@ -36,7 +34,7 @@ def compatibility(is_backward_compatible: bool = False) -> Callable: ...@@ -36,7 +34,7 @@ def compatibility(is_backward_compatible: bool = False) -> Callable:
else: else:
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}') raise RuntimeError(f"Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}")
return wrapper return wrapper
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations # for more meta_registrations
from typing import Callable, List, Optional, Tuple, Union from typing import List, Optional, Union
import torch import torch
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
...@@ -16,13 +16,11 @@ meta_table = {} ...@@ -16,13 +16,11 @@ meta_table = {}
def register_meta(op, register_dispatcher=True): def register_meta(op, register_dispatcher=True):
def wrapper(f): def wrapper(f):
def add_func(op): def add_func(op):
meta_table[op] = f meta_table[op] = f
if register_dispatcher: if register_dispatcher:
name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
try: try:
meta_lib.impl(name, f) meta_lib.impl(name, f)
except: except:
...@@ -48,7 +46,6 @@ def meta_conv( ...@@ -48,7 +46,6 @@ def meta_conv(
output_padding: List[int], output_padding: List[int],
groups: int, groups: int,
): ):
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
""" """
Formula to apply to calculate the length of some dimension of the output Formula to apply to calculate the length of some dimension of the output
...@@ -125,7 +122,8 @@ def meta_conv( ...@@ -125,7 +122,8 @@ def meta_conv(
kernel_size[i], kernel_size[i],
stride[i], stride[i],
output_padding_list[i], output_padding_list[i],
)) )
)
else: else:
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])) ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
return ret_shape return ret_shape
...@@ -159,22 +157,42 @@ def meta_conv( ...@@ -159,22 +157,42 @@ def meta_conv(
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format() mem_fmt = pick_memory_format()
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out return out
@register_meta(aten._convolution.default) @register_meta(aten._convolution.default)
def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int], def meta_conv_1(
padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int, input_tensor: torch.Tensor,
*extra_args): weight: torch.Tensor,
bias: torch.Tensor,
stride: List[int],
padding: List[int],
dilation: List[int],
is_transposed: bool,
output_padding: List[int],
groups: int,
*extra_args,
):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out return out
@register_meta(aten.convolution_backward.default) @register_meta(aten.convolution_backward.default)
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, def meta_conv_backward(
padding, dilation, transposed, output_padding, groups, output_mask): grad_output: torch.Tensor,
return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta') input: torch.Tensor,
weight: torch.Tensor,
bias_sizes,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_mask,
):
return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device="meta")
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
...@@ -208,7 +226,6 @@ def meta_cuda_rnn( ...@@ -208,7 +226,6 @@ def meta_cuda_rnn(
batch_sizes, batch_sizes,
dropout_state, dropout_state,
): ):
is_input_packed = len(batch_sizes) != 0 is_input_packed = len(batch_sizes) != 0
if is_input_packed: if is_input_packed:
seq_length = len(batch_sizes) seq_length = len(batch_sizes)
...@@ -224,8 +241,11 @@ def meta_cuda_rnn( ...@@ -224,8 +241,11 @@ def meta_cuda_rnn(
if is_input_packed: if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions] out_shape = [batch_sizes_sum, out_size * num_directions]
else: else:
out_shape = ([mini_batch, seq_length, out_size * out_shape = (
num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions]) [mini_batch, seq_length, out_size * num_directions]
if batch_first
else [seq_length, mini_batch, out_size * num_directions]
)
output = input.new_empty(out_shape) output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size] cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
...@@ -242,18 +262,20 @@ def meta_cuda_rnn( ...@@ -242,18 +262,20 @@ def meta_cuda_rnn(
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default) @register_meta(aten._cudnn_rnn_backward.default)
def meta_cudnn_rnn_backward(input: torch.Tensor, def meta_cudnn_rnn_backward(
weight: torch.Tensor, input: torch.Tensor,
weight_stride0: int, weight: torch.Tensor,
hx: torch.Tensor, weight_stride0: int,
cx: Optional[torch.Tensor] = None, hx: torch.Tensor,
*args, cx: Optional[torch.Tensor] = None,
**kwargs): *args,
**kwargs,
):
print(input, weight, hx, cx) print(input, weight, hx, cx)
grad_input = torch.empty_like(input) grad_input = torch.empty_like(input)
grad_weight = torch.empty_like(weight) grad_weight = torch.empty_like(weight)
grad_hx = torch.empty_like(hx) grad_hx = torch.empty_like(hx)
grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta') grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device="meta")
return grad_input, grad_weight, grad_hx, grad_cx return grad_input, grad_weight, grad_hx, grad_cx
...@@ -298,15 +320,25 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini ...@@ -298,15 +320,25 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
n_input = input.size(1) n_input = input.size(1)
output = torch.empty_like(input) output = torch.empty_like(input)
running_mean = torch.empty((n_input), device='meta') running_mean = torch.empty((n_input), device="meta")
running_var = torch.empty((n_input), device='meta') running_var = torch.empty((n_input), device="meta")
return output, running_mean, running_var return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default) @register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean, def meta_bn_backward(
save_invstd, train, eps, output_mask): dY: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
running_mean,
running_var,
save_mean,
save_invstd,
train,
eps,
output_mask,
):
dX = torch.empty_like(input) dX = torch.empty_like(input)
dgamma = torch.empty_like(weight) dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight) dbeta = torch.empty_like(weight)
...@@ -319,9 +351,9 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, ...@@ -319,9 +351,9 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
n_input = input.size(1) n_input = input.size(1)
output = torch.empty_like(input) output = torch.empty_like(input)
running_mean = torch.empty((n_input), device='meta') running_mean = torch.empty((n_input), device="meta")
running_var = torch.empty((n_input), device='meta') running_var = torch.empty((n_input), device="meta")
reserve = torch.empty((0), dtype=torch.uint8, device='meta') reserve = torch.empty((0), dtype=torch.uint8, device="meta")
return output, running_mean, running_var, reserve return output, running_mean, running_var, reserve
...@@ -330,8 +362,17 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, ...@@ -330,8 +362,17 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
# in training mode (evaluation mode batchnorm has a different algorithm), # in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter. # which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default) @register_meta(aten.cudnn_batch_norm_backward.default)
def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, def meta_cudnn_bn_backward(
save_mean, save_invstd, eps, reserve): dY: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
running_mean,
running_var,
save_mean,
save_invstd,
eps,
reserve,
):
dX = torch.empty_like(input) dX = torch.empty_like(input)
dgamma = torch.empty_like(weight) dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight) dbeta = torch.empty_like(weight)
...@@ -345,15 +386,16 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): ...@@ -345,15 +386,16 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
n_input = input.size(1) n_input = input.size(1)
output = torch.empty_like(input) output = torch.empty_like(input)
running_mean = torch.empty((bs, n_input, 1), device='meta') running_mean = torch.empty((bs, n_input, 1), device="meta")
running_var = torch.empty((bs, n_input, 1), device='meta') running_var = torch.empty((bs, n_input, 1), device="meta")
return output, running_mean, running_var return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default) @register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, def meta_ln_backward(
grad_input_mask): dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
):
dX = torch.empty_like(input) dX = torch.empty_like(input)
dgamma = torch.empty_like(weight) dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(bias) dbeta = torch.empty_like(bias)
...@@ -397,16 +439,19 @@ def meta_index_Tensor(self, indices): ...@@ -397,16 +439,19 @@ def meta_index_Tensor(self, indices):
result: List[Optional[torch.Tensor]] = [] result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices): for i, index in enumerate(indices):
if index is not None: if index is not None:
assert index.dtype in [torch.long, torch.int8, torch.bool],\ assert index.dtype in [
"tensors used as indices must be long, byte or bool tensors" torch.long,
torch.int8,
torch.bool,
], "tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]: if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero() nonzero = index.nonzero()
k = len(result) k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim): for j in range(index.ndim):
assert index.shape[j] == self.shape[ assert (
k + index.shape[j] == self.shape[k + j]
j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j)) result.append(nonzero.select(1, j))
else: else:
result.append(index) result.append(index)
...@@ -482,12 +527,15 @@ def meta_index_Tensor(self, indices): ...@@ -482,12 +527,15 @@ def meta_index_Tensor(self, indices):
# ============================== Embedding ========================================= # ============================== Embedding =========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default) @register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, def meta_embedding_dense_backward(
scale_grad_by_freq): grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
return torch.empty((num_weights, grad_output.size(-1)), ):
dtype=grad_output.dtype, return torch.empty(
device=grad_output.device, (num_weights, grad_output.size(-1)),
layout=grad_output.layout) dtype=grad_output.dtype,
device=grad_output.device,
layout=grad_output.layout,
)
# ============================== Dropout =========================================== # ============================== Dropout ===========================================
......
from typing import Any, Callable, Dict, Iterable, List, Tuple from typing import Any, Dict, Iterable, List, Tuple
import torch import torch
...@@ -18,6 +18,7 @@ try: ...@@ -18,6 +18,7 @@ try:
magic_methods, magic_methods,
) )
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
CODEGEN_AVAILABLE = True CODEGEN_AVAILABLE = True
except: except:
from torch.fx.graph import ( from torch.fx.graph import (
...@@ -32,12 +33,13 @@ except: ...@@ -32,12 +33,13 @@ except:
magic_methods, magic_methods,
) )
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
CODEGEN_AVAILABLE = False CODEGEN_AVAILABLE = False
if CODEGEN_AVAILABLE: if CODEGEN_AVAILABLE:
__all__ = ['ActivationCheckpointCodeGen'] __all__ = ["ActivationCheckpointCodeGen"]
else: else:
__all__ = ['python_code_with_activation_checkpoint'] __all__ = ["python_code_with_activation_checkpoint"]
def _gen_saved_tensors_hooks(): def _gen_saved_tensors_hooks():
...@@ -125,15 +127,14 @@ def _find_ckpt_regions(nodes: List[Node]): ...@@ -125,15 +127,14 @@ def _find_ckpt_regions(nodes: List[Node]):
Find the checkpoint regions given a list of consecutive nodes. The outputs will be list Find the checkpoint regions given a list of consecutive nodes. The outputs will be list
of tuples, each tuple is in the form of (start_index, end_index). of tuples, each tuple is in the form of (start_index, end_index).
""" """
ckpt_nodes = []
ckpt_regions = [] ckpt_regions = []
start = -1 start = -1
end = -1 end = -1
current_region = None current_region = None
for idx, node in enumerate(nodes): for idx, node in enumerate(nodes):
if 'activation_checkpoint' in node.meta: if "activation_checkpoint" in node.meta:
act_ckpt_label = node.meta['activation_checkpoint'] act_ckpt_label = node.meta["activation_checkpoint"]
# this activation checkpoint label is not set yet # this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region # meaning this is the first node of the activation ckpt region
...@@ -150,7 +151,7 @@ def _find_ckpt_regions(nodes: List[Node]): ...@@ -150,7 +151,7 @@ def _find_ckpt_regions(nodes: List[Node]):
current_region = act_ckpt_label current_region = act_ckpt_label
start = idx start = idx
end = -1 end = -1
elif current_region is not None and not 'activation_checkpoint' in node.meta: elif current_region is not None and not "activation_checkpoint" in node.meta:
# used to check the case below # used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt] # node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1 end = idx - 1
...@@ -178,8 +179,8 @@ def _find_offload_regions(nodes: List[Node]): ...@@ -178,8 +179,8 @@ def _find_offload_regions(nodes: List[Node]):
current_region = None current_region = None
for idx, node in enumerate(nodes): for idx, node in enumerate(nodes):
if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable): if "activation_offload" in node.meta and isinstance(node.meta["activation_offload"], Iterable):
act_offload_label = node.meta['activation_offload'] act_offload_label = node.meta["activation_offload"]
if current_region == None: if current_region == None:
current_region = act_offload_label current_region = act_offload_label
...@@ -226,9 +227,9 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen ...@@ -226,9 +227,9 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
""" """
Generate the checkpoint function call code text Generate the checkpoint function call code text
""" """
outputs = ', '.join(output_vars) outputs = ", ".join(output_vars)
inputs = ', '.join(input_vars) inputs = ", ".join(input_vars)
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})' return f"{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})"
def _end_of_ckpt(node: Node, check_idx: int) -> bool: def _end_of_ckpt(node: Node, check_idx: int) -> bool:
...@@ -240,9 +241,9 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool: ...@@ -240,9 +241,9 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool:
Returns: Returns:
bool bool
""" """
if 'activation_checkpoint' in node.meta: if "activation_checkpoint" in node.meta:
if isinstance(node.meta['activation_checkpoint'], list): if isinstance(node.meta["activation_checkpoint"], list):
return node.meta['activation_checkpoint'][check_idx] == None return node.meta["activation_checkpoint"][check_idx] == None
else: else:
return False return False
else: else:
...@@ -260,11 +261,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0): ...@@ -260,11 +261,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
current_region = None current_region = None
for idx, node in enumerate(nodes): for idx, node in enumerate(nodes):
if 'activation_checkpoint' in node.meta: if "activation_checkpoint" in node.meta:
if isinstance(node.meta['activation_checkpoint'], int): if isinstance(node.meta["activation_checkpoint"], int):
act_ckpt_label = node.meta['activation_checkpoint'] act_ckpt_label = node.meta["activation_checkpoint"]
else: else:
act_ckpt_label = node.meta['activation_checkpoint'][check_idx] act_ckpt_label = node.meta["activation_checkpoint"][check_idx]
# this activation checkpoint label is not set yet # this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region # meaning this is the first node of the activation ckpt region
...@@ -298,13 +299,9 @@ def _find_nested_ckpt_regions(nodes, check_idx=0): ...@@ -298,13 +299,9 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
return ckpt_regions return ckpt_regions
def emit_ckpt_func(body, def emit_ckpt_func(
ckpt_func, body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, level=0, in_ckpt=False
node_list: List[Node], ):
emit_node_func,
delete_unused_value_func,
level=0,
in_ckpt=False):
"""Emit ckpt function in nested way """Emit ckpt function in nested way
Args: Args:
body: forward code, in recursive calls, this part will be checkpoint body: forward code, in recursive calls, this part will be checkpoint
...@@ -321,17 +318,17 @@ def emit_ckpt_func(body, ...@@ -321,17 +318,17 @@ def emit_ckpt_func(body,
inputs, outputs = _find_input_and_output_nodes(node_list) inputs, outputs = _find_input_and_output_nodes(node_list)
# if the current checkpoint function use int as label, using old generation method # if the current checkpoint function use int as label, using old generation method
if isinstance(node_list[0].meta['activation_checkpoint'], int): if isinstance(node_list[0].meta["activation_checkpoint"], int):
label = node_list[0].meta['activation_checkpoint'] label = node_list[0].meta["activation_checkpoint"]
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n') ckpt_func.append(f"{ckpt_fn_def}\n")
for node in node_list: for node in node_list:
emit_node_func(node, ckpt_func) emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1] ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func) delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
activation_offload = node_list[0].meta.get('activation_offload', False) activation_offload = node_list[0].meta.get("activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
usage += "\n" usage += "\n"
body.append(usage) body.append(usage)
...@@ -340,12 +337,12 @@ def emit_ckpt_func(body, ...@@ -340,12 +337,12 @@ def emit_ckpt_func(body,
else: else:
# label given by each layer, e.g. if you are currently at level [0, 1, 1] # label given by each layer, e.g. if you are currently at level [0, 1, 1]
# the label will be '0_1_1' # the label will be '0_1_1'
label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]]) label = "_".join([str(idx) for idx in node_list[0].meta["activation_checkpoint"][: level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n') ckpt_func.append(f"{ckpt_fn_def}\n")
# if there is more level to fetch # if there is more level to fetch
if level + 1 < len(node_list[0].meta['activation_checkpoint']): if level + 1 < len(node_list[0].meta["activation_checkpoint"]):
ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1) ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
start_idx = [item[0] for item in ckpt_regions] start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions] end_idx = [item[1] for item in ckpt_regions]
...@@ -358,38 +355,45 @@ def emit_ckpt_func(body, ...@@ -358,38 +355,45 @@ def emit_ckpt_func(body,
break break
if node_idx in start_idx: if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, emit_ckpt_func(
delete_unused_value_func, level + 1, True) ckpt_func,
ckpt_func_buffer,
ckpt_node_list,
emit_node_func,
delete_unused_value_func,
level + 1,
True,
)
node_idx += len(ckpt_node_list) node_idx += len(ckpt_node_list)
else: else:
node = node_list[node_idx] node = node_list[node_idx]
emit_node_func(node, ckpt_func) emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1] ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func) delete_unused_value_func(node, ckpt_func)
node_idx += 1 node_idx += 1
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
ckpt_func += ckpt_func_buffer ckpt_func += ckpt_func_buffer
activation_offload = node_list[0].meta.get('activation_offload', False) activation_offload = node_list[0].meta.get("activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n"
if in_ckpt: if in_ckpt:
usage = ' ' + usage usage = " " + usage
body.append(usage) body.append(usage)
# last level # last level
else: else:
for node in node_list: for node in node_list:
emit_node_func(node, ckpt_func) emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1] ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func) delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
activation_offload = node_list[0].meta.get('activation_offload', False) activation_offload = node_list[0].meta.get("activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n"
if in_ckpt: if in_ckpt:
usage = ' ' + usage usage = " " + usage
body.append(usage) body.append(usage)
...@@ -420,7 +424,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod ...@@ -420,7 +424,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# find the input and output var names for each offload region # find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions): for idx, (start, end) in enumerate(offload_regions):
offload_node_list = node_list[start:end + 1] offload_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list) inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs) offload_inputs.append(inputs)
offload_outputs.append(outputs) offload_outputs.append(outputs)
...@@ -436,7 +440,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod ...@@ -436,7 +440,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# process ckpt_regions # process ckpt_regions
if node_idx in start_idx: if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func) emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list) node_idx += len(ckpt_node_list)
...@@ -470,7 +474,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod ...@@ -470,7 +474,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
if within_offload_region: if within_offload_region:
emit_node_func(node, body) emit_node_func(node, body)
body[-1] = ' ' + body[-1] body[-1] = " " + body[-1]
delete_unused_value_func(node, body) delete_unused_value_func(node, body)
else: else:
...@@ -508,14 +512,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, ...@@ -508,14 +512,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# find the input and output var names for each region # find the input and output var names for each region
for idx, (start, end) in enumerate(ckpt_regions): for idx, (start, end) in enumerate(ckpt_regions):
ckpt_node_list = node_list[start:end + 1] ckpt_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(ckpt_node_list) inputs, outputs = _find_input_and_output_nodes(ckpt_node_list)
input_vars.append(inputs) input_vars.append(inputs)
output_vars.append(outputs) output_vars.append(outputs)
# find the input and output var names for each offload region # find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions): for idx, (start, end) in enumerate(offload_regions):
offload_node_list = node_list[start:end + 1] offload_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list) inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs) offload_inputs.append(inputs)
offload_outputs.append(outputs) offload_outputs.append(outputs)
...@@ -527,7 +531,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, ...@@ -527,7 +531,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if idx in start_idx: if idx in start_idx:
label = start_idx.index(idx) label = start_idx.index(idx)
ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label]) ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
ckpt_func.append(f'{ckpt_fn_def}\n') ckpt_func.append(f"{ckpt_fn_def}\n")
within_ckpt_region = True within_ckpt_region = True
if idx in offload_starts: if idx in offload_starts:
...@@ -559,12 +563,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, ...@@ -559,12 +563,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# NOTE: currently we separate body and ckpt_func definition # NOTE: currently we separate body and ckpt_func definition
if within_ckpt_region: if within_ckpt_region:
emit_node_func(node, ckpt_func) emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1] ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func) delete_unused_value_func(node, ckpt_func)
elif within_offload_region: elif within_offload_region:
emit_node_func(node, body) emit_node_func(node, body)
body[-1] = ' ' + body[-1] body[-1] = " " + body[-1]
delete_unused_value_func(node, body) delete_unused_value_func(node, body)
else: else:
...@@ -576,13 +580,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, ...@@ -576,13 +580,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# generate return statement # generate return statement
label = end_idx.index(idx) label = end_idx.index(idx)
return_statement = _gen_ckpt_output(output_vars[label]) return_statement = _gen_ckpt_output(output_vars[label])
return_statement = f' {return_statement}\n\n' return_statement = f" {return_statement}\n\n"
ckpt_func.append(return_statement) ckpt_func.append(return_statement)
# we need to check if the checkpoint need to offload the input # we need to check if the checkpoint need to offload the input
start_node_idx = start_idx[label] start_node_idx = start_idx[label]
if 'activation_offload' in node_list[start_node_idx].meta: if "activation_offload" in node_list[start_node_idx].meta:
activation_offload = node_list[start_node_idx].meta['activation_offload'] activation_offload = node_list[start_node_idx].meta["activation_offload"]
else: else:
activation_offload = False activation_offload = False
...@@ -594,8 +598,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, ...@@ -594,8 +598,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if input_node.op != "placeholder": if input_node.op != "placeholder":
non_leaf_input = 1 non_leaf_input = 1
for user in input_node.users: for user in input_node.users:
if 'activation_checkpoint' in user.meta: if "activation_checkpoint" in user.meta:
if user.meta['activation_checkpoint'] == label: if user.meta["activation_checkpoint"] == label:
if user.op == "call_module": if user.op == "call_module":
if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"): if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
...@@ -610,7 +614,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, ...@@ -610,7 +614,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# generate checkpoint function call in a new line # generate checkpoint function call in a new line
usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant) usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
usage += '\n' usage += "\n"
body.append(usage) body.append(usage)
within_ckpt_region = False within_ckpt_region = False
...@@ -621,7 +625,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, ...@@ -621,7 +625,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if CODEGEN_AVAILABLE: if CODEGEN_AVAILABLE:
class ActivationCheckpointCodeGen(CodeGen): class ActivationCheckpointCodeGen(CodeGen):
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = [] free_vars: List[str] = []
body: List[str] = [] body: List[str] = []
...@@ -629,7 +632,7 @@ if CODEGEN_AVAILABLE: ...@@ -629,7 +632,7 @@ if CODEGEN_AVAILABLE:
wrapped_fns: Dict[str, None] = {} wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference # Wrap string in list to pass by reference
maybe_return_annotation: List[str] = [''] maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any): def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global. """Add an obj to be tracked as a global.
...@@ -637,7 +640,7 @@ if CODEGEN_AVAILABLE: ...@@ -637,7 +640,7 @@ if CODEGEN_AVAILABLE:
Graph, like functions or types. Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source. Returns: the global name that should be used to reference 'obj' in generated source.
""" """
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We # HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their # can't import them like normal modules so they must retain their
# fully qualified name. # fully qualified name.
...@@ -662,16 +665,16 @@ if CODEGEN_AVAILABLE: ...@@ -662,16 +665,16 @@ if CODEGEN_AVAILABLE:
def type_repr(o: Any): def type_repr(o: Any):
if o == (): if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()] # Empty tuple is used for empty tuple type annotation Tuple[()]
return '()' return "()"
typename = _type_repr(o) typename = _type_repr(o)
if hasattr(o, '__origin__'): if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor] # This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type) origin_typename = add_global(_type_repr(origin_type), origin_type)
if hasattr(o, '__args__'): if hasattr(o, "__args__"):
# Assign global names for each of the inner type variables. # Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__] args = [type_repr(arg) for arg in o.__args__]
...@@ -690,19 +693,18 @@ if CODEGEN_AVAILABLE: ...@@ -690,19 +693,18 @@ if CODEGEN_AVAILABLE:
return add_global(typename, o) return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg): def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global. # Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, '_fields'): if isinstance(arg, tuple) and hasattr(arg, "_fields"):
qualified_name = _get_qualified_name(type(arg)) qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg)) global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}" return f"{global_name}{repr(tuple(arg))}"
return repr(arg) return repr(arg)
args_s = ', '.join(_get_repr(a) for a in args) args_s = ", ".join(_get_repr(a) for a in args)
kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
if args_s and kwargs_s: if args_s and kwargs_s:
return f'{args_s}, {kwargs_s}' return f"{args_s}, {kwargs_s}"
return args_s or kwargs_s return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use # Run through reverse nodes and record the first instance of a use
...@@ -728,90 +730,101 @@ if CODEGEN_AVAILABLE: ...@@ -728,90 +730,101 @@ if CODEGEN_AVAILABLE:
not used in the remainder of the code are freed and the memory usage not used in the remainder of the code are freed and the memory usage
of the code is optimal. of the code is optimal.
""" """
if user.op == 'placeholder': if user.op == "placeholder":
return return
if user.op == 'output': if user.op == "output":
body.append('\n') body.append("\n")
return return
nodes_to_delete = user_to_last_uses.get(user, []) nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete): if len(nodes_to_delete):
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
body.append(f'; {to_delete_str}\n') body.append(f"; {to_delete_str}\n")
else: else:
body.append('\n') body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func # NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body): def emit_node(node: Node, body):
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
if node.op == 'placeholder': if node.op == "placeholder":
assert isinstance(node.target, str) assert isinstance(node.target, str)
maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace('*', '') raw_name = node.target.replace("*", "")
if raw_name != repr(node): if raw_name != repr(node):
body.append(f'{repr(node)} = {raw_name}\n') body.append(f"{repr(node)} = {raw_name}\n")
return return
elif node.op == 'call_method': elif node.op == "call_method":
assert isinstance(node.target, str) assert isinstance(node.target, str)
body.append( body.append(
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
f'({_format_args(node.args[1:], node.kwargs)})') f"({_format_args(node.args[1:], node.kwargs)})"
)
return return
elif node.op == 'call_function': elif node.op == "call_function":
assert callable(node.target) assert callable(node.target)
# pretty print operators # pretty print operators
if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple) assert isinstance(node.args, tuple)
body.append(f'{repr(node)}{maybe_type_annotation} = ' body.append(
f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
)
return return
# pretty print inplace operators; required for jit.script to work properly # pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo # not currently supported in normal FX graphs, but generated by torchdynamo
if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' body.append(
f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
)
return return
qualified_name = _get_qualified_name(node.target) qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target) global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument # special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if global_name == 'getattr' and \ if (
isinstance(node.args, tuple) and \ global_name == "getattr"
isinstance(node.args[1], str) and \ and isinstance(node.args, tuple)
node.args[1].isidentifier() and \ and isinstance(node.args[1], str)
len(node.args) == 2: and node.args[1].isidentifier()
and len(node.args) == 2
):
body.append( body.append(
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
)
return return
body.append( body.append(
f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
if node.meta.get('is_wrapped', False): )
if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name) wrapped_fns.setdefault(global_name)
return return
elif node.op == 'call_module': elif node.op == "call_module":
assert isinstance(node.target, str) assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = ' body.append(
f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
)
return return
elif node.op == 'get_attr': elif node.op == "get_attr":
assert isinstance(node.target, str) assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return return
elif node.op == 'output': elif node.op == "output":
if node.type is not None: if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}" maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0])) body.append(self.generate_output(node.args[0]))
return return
raise NotImplementedError(f'node: {node.op} {node.target}') raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing # Modified for activation checkpointing
ckpt_func = [] ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we # if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen # will use nested type of activation checkpoint codegen
if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes): if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else: else:
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
...@@ -820,13 +833,13 @@ if CODEGEN_AVAILABLE: ...@@ -820,13 +833,13 @@ if CODEGEN_AVAILABLE:
# If the Graph has no non-placeholder nodes, no lines for the body # If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a # have been emitted. To continue to have valid Python code, emit a
# single pass statement # single pass statement
body.append('pass\n') body.append("pass\n")
if len(wrapped_fns) > 0: if len(wrapped_fns) > 0:
wrap_name = add_global('wrap', torch.fx.wrap) wrap_name = add_global("wrap", torch.fx.wrap)
wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else: else:
wrap_stmts = '' wrap_stmts = ""
if self._body_transformer: if self._body_transformer:
body = self._body_transformer(body) body = self._body_transformer(body)
...@@ -837,11 +850,11 @@ if CODEGEN_AVAILABLE: ...@@ -837,11 +850,11 @@ if CODEGEN_AVAILABLE:
# as we need colossalai.utils.checkpoint, we need to import colossalai # as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function # in forward function
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = ''.join(ckpt_func) + prologue prologue = "".join(ckpt_func) + prologue
prologue = prologue prologue = prologue
code = ''.join(body) code = "".join(body)
code = '\n'.join(' ' + line for line in code.split('\n')) code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f""" fn_code = f"""
{wrap_stmts} {wrap_stmts}
{prologue} {prologue}
...@@ -861,7 +874,7 @@ else: ...@@ -861,7 +874,7 @@ else:
wrapped_fns: Dict[str, None] = {} wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference # Wrap string in list to pass by reference
maybe_return_annotation: List[str] = [''] maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any): def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global. """Add an obj to be tracked as a global.
...@@ -869,7 +882,7 @@ else: ...@@ -869,7 +882,7 @@ else:
Graph, like functions or types. Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source. Returns: the global name that should be used to reference 'obj' in generated source.
""" """
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We # HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their # can't import them like normal modules so they must retain their
# fully qualified name. # fully qualified name.
...@@ -894,12 +907,12 @@ else: ...@@ -894,12 +907,12 @@ else:
def type_repr(o: Any): def type_repr(o: Any):
if o == (): if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()] # Empty tuple is used for empty tuple type annotation Tuple[()]
return '()' return "()"
typename = _type_repr(o) typename = _type_repr(o)
# This is a generic type, e.g. typing.List[torch.Tensor] # This is a generic type, e.g. typing.List[torch.Tensor]
if hasattr(o, '__origin__'): if hasattr(o, "__origin__"):
origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type) origin_typename = add_global(_type_repr(origin_type), origin_type)
...@@ -934,84 +947,94 @@ else: ...@@ -934,84 +947,94 @@ else:
not used in the remainder of the code are freed and the memory usage not used in the remainder of the code are freed and the memory usage
of the code is optimal. of the code is optimal.
""" """
if user.op == 'placeholder': if user.op == "placeholder":
return return
if user.op == 'output': if user.op == "output":
body.append('\n') body.append("\n")
return return
nodes_to_delete = user_to_last_uses.get(user, []) nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete): if len(nodes_to_delete):
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
body.append(f'; {to_delete_str}\n') body.append(f"; {to_delete_str}\n")
else: else:
body.append('\n') body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func # NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
    """Append the generated source for one graph node to ``body``.

    The output buffer is a parameter rather than a closed-over name so the
    caller can direct the emitted text either to the main function body or
    to a checkpoint-function buffer.

    Args:
        node: the node to emit; dispatch is on ``node.op``.
        body: list of source fragments; the generated text is appended to it.

    Raises:
        NotImplementedError: if ``node.op`` is not one of the handled kinds
            (placeholder, call_method, call_function, call_module, get_attr,
            output).
    """
    maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"

    if node.op == "placeholder":
        assert isinstance(node.target, str)
        maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
        # Placeholders become parameters of the generated function.
        free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
        # Targets may carry '*' characters; when the stripped name differs
        # from the node's own name, emit an aliasing assignment.
        raw_name = node.target.replace("*", "")
        if raw_name != repr(node):
            body.append(f"{repr(node)} = {raw_name}\n")
        return
    elif node.op == "call_method":
        assert isinstance(node.target, str)
        # The first positional argument is the receiver of the method call.
        body.append(
            f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
            f"({_format_args(node.args[1:], node.kwargs)})"
        )
        return
    elif node.op == "call_function":
        assert callable(node.target)
        # pretty print operators
        if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
            assert isinstance(node.args, tuple)
            body.append(
                f"{repr(node)}{maybe_type_annotation} = "
                f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
            )
            return
        qualified_name = _get_qualified_name(node.target)
        global_name = add_global(qualified_name, node.target)
        # special case for getattr: node.args could be 2-argument or 3-argument
        # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
        # The length is checked before node.args[1] is indexed so that a
        # short argument tuple cannot raise IndexError here.
        if (
            global_name == "getattr"
            and isinstance(node.args, tuple)
            and len(node.args) == 2
            and isinstance(node.args[1], str)
            and node.args[1].isidentifier()
        ):
            body.append(
                f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
            )
            return
        body.append(
            f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
        )
        if node.meta.get("is_wrapped", False):
            # Only the key matters: wrapped_fns collects the names for which
            # wrap statements are generated later.
            wrapped_fns.setdefault(global_name)
        return
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        body.append(
            f"{repr(node)}{maybe_type_annotation} = "
            f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
        )
        return
    elif node.op == "get_attr":
        assert isinstance(node.target, str)
        body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
        return
    elif node.op == "output":
        if node.type is not None:
            # One-element list used as a mutable cell for the return annotation.
            maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
        if self._pytree_info is None:
            body.append(f"return {repr(node.args[0])}")
        else:
            body.append(f"return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)")
        return
    raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing # Modified for activation checkpointing
ckpt_func = [] ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we # if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen # will use nested type of activation checkpoint codegen
if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes): if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in self.nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else: else:
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
...@@ -1020,33 +1043,34 @@ else: ...@@ -1020,33 +1043,34 @@ else:
# If the Graph has no non-placeholder nodes, no lines for the body # If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a # have been emitted. To continue to have valid Python code, emit a
# single pass statement # single pass statement
body.append('pass\n') body.append("pass\n")
if self._pytree_info is not None: if self._pytree_info is not None:
orig_args = self._pytree_info.orig_args orig_args = self._pytree_info.orig_args
has_orig_self = (orig_args[0] == 'self') has_orig_self = orig_args[0] == "self"
if has_orig_self: if has_orig_self:
free_vars.insert(0, 'self') free_vars.insert(0, "self")
if len(free_vars) > 0: # pytree has placeholders in it if len(free_vars) > 0: # pytree has placeholders in it
body.insert( body.insert(
0, 0,
f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n") f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n",
)
else: else:
orig_args = free_vars orig_args = free_vars
if len(wrapped_fns) > 0: if len(wrapped_fns) > 0:
wrap_name = add_global('wrap', torch.fx.wrap) wrap_name = add_global("wrap", torch.fx.wrap)
wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else: else:
wrap_stmts = '' wrap_stmts = ""
ckpt_func = ''.join(ckpt_func) ckpt_func = "".join(ckpt_func)
# If the original function didn't have self as its first argument, we # If the original function didn't have self as its first argument, we
# would have added it. # would have added it.
if len(orig_args) == 0 or orig_args[0] != 'self': if len(orig_args) == 0 or orig_args[0] != "self":
orig_args.insert(0, 'self') orig_args.insert(0, "self")
code = ''.join(body) code = "".join(body)
code = '\n'.join(' ' + line for line in code.split('\n')) code = "\n".join(" " + line for line in code.split("\n"))
# as we need colossalai.utils.checkpoint, we need to import colossalai # as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function # in forward function
......
import os import os
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Type, Union from typing import Any, Dict, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.modules.module import _addindent from torch.nn.modules.module import _addindent
try: try:
from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen from torch.fx.graph import Graph, PythonCode, _PyTreeCodeGen
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall from torch.fx.graph_module import GraphModule, _exec_with_source, _forward_from_src, _WrappedCall
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
COLOGM = True COLOGM = True
except: except:
from torch.fx.graph import Graph from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule from torch.fx.graph_module import GraphModule
COLOGM = False COLOGM = False
if COLOGM: if COLOGM:
class ColoGraphModule(GraphModule): class ColoGraphModule(GraphModule):
def __init__(
def __init__(self, self,
root: Union[torch.nn.Module, Dict[str, Any]], root: Union[torch.nn.Module, Dict[str, Any]],
graph: Graph, graph: Graph,
class_name: str = 'GraphModule', class_name: str = "GraphModule",
ckpt_codegen: bool = True): ckpt_codegen: bool = True,
):
if ckpt_codegen: if ckpt_codegen:
graph.set_codegen(ActivationCheckpointCodeGen()) graph.set_codegen(ActivationCheckpointCodeGen())
super().__init__(root, graph, class_name) super().__init__(root, graph, class_name)
...@@ -60,7 +63,7 @@ if COLOGM: ...@@ -60,7 +63,7 @@ if COLOGM:
if isinstance(self._graph._codegen, _PyTreeCodeGen): if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec self._out_spec = self._graph._codegen.pytree_info.out_spec
python_code = self._graph.python_code(root_module='self') python_code = self._graph.python_code(root_module="self")
self._code = python_code.src self._code = python_code.src
# To split ckpt functions code and forward code # To split ckpt functions code and forward code
...@@ -83,8 +86,8 @@ if COLOGM: ...@@ -83,8 +86,8 @@ if COLOGM:
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing. # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None cls_call = cls.__call__ if "__call__" in vars(cls) else None
if '_wrapped_call' not in vars(cls): if "_wrapped_call" not in vars(cls):
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs): def call_wrapped(self, *args, **kwargs):
return self._wrapped_call(self, *args, **kwargs) return self._wrapped_call(self, *args, **kwargs)
...@@ -108,7 +111,7 @@ if COLOGM: ...@@ -108,7 +111,7 @@ if COLOGM:
""" """
folder = Path(folder) folder = Path(folder)
Path(folder).mkdir(exist_ok=True) Path(folder).mkdir(exist_ok=True)
torch.save(self.state_dict(), folder / 'state_dict.pt') torch.save(self.state_dict(), folder / "state_dict.pt")
tab = " " * 4 tab = " " * 4
# we add import colossalai here # we add import colossalai here
...@@ -125,7 +128,13 @@ class {module_name}(torch.nn.Module): ...@@ -125,7 +128,13 @@ class {module_name}(torch.nn.Module):
def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
safe_reprs = [ safe_reprs = [
nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d nn.Linear,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
] ]
if type(module) in safe_reprs: if type(module) in safe_reprs:
return f"{module.__repr__()}" return f"{module.__repr__()}"
...@@ -136,10 +145,10 @@ class {module_name}(torch.nn.Module): ...@@ -136,10 +145,10 @@ class {module_name}(torch.nn.Module):
for module_name, module in self.named_children(): for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module) module_str = _gen_model_repr(module_name, module)
if module_str is None: if module_str is None:
module_file = folder / f'{module_name}.pt' module_file = folder / f"{module_name}.pt"
torch.save(module, module_file) torch.save(module, module_file)
blobified_modules.append(module_name) blobified_modules.append(module_name)
module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
module_str = f"torch.load(r'{module_file}') # {module_repr}" module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n" model_str += f"{tab*2}self.{module_name} = {module_str}\n"
...@@ -156,19 +165,20 @@ class {module_name}(torch.nn.Module): ...@@ -156,19 +165,20 @@ class {module_name}(torch.nn.Module):
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n" model_str += f"{_addindent(self.code, 4)}\n"
module_file = folder / 'module.py' module_file = folder / "module.py"
module_file.write_text(model_str) module_file.write_text(model_str)
init_file = folder / '__init__.py' init_file = folder / "__init__.py"
init_file.write_text('from .module import *') init_file.write_text("from .module import *")
if len(blobified_modules) > 0: if len(blobified_modules) > 0:
warnings.warn("Was not able to save the following children modules as reprs -" warnings.warn(
f"saved as pickled files instead: {blobified_modules}") "Was not able to save the following children modules as reprs -"
f"saved as pickled files instead: {blobified_modules}"
)
else: else:
class ColoGraphModule(GraphModule):
    """Fallback ``ColoGraphModule`` used when ``COLOGM`` is false.

    This variant adds nothing on top of ``GraphModule``: it exists so the
    name ``ColoGraphModule`` is defined on both branches of the ``COLOGM``
    switch.
    """

    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = "GraphModule"):
        """Forward all arguments unchanged to ``GraphModule.__init__``.

        Args:
            root: module (or mapping) passed through as the ``root`` argument.
            graph: the graph passed through as the ``graph`` argument.
            class_name: passed through as the ``class_name`` argument.
        """
        super().__init__(root, graph, class_name)
import numpy as np import numpy as np
import torch import torch
import tqdm import tqdm
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module from colossalai.fx.passes.split_module import split_module
...@@ -29,15 +27,15 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01): ...@@ -29,15 +27,15 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
accumulate_bwd_flop = 0 accumulate_bwd_flop = 0
block_nodes = [] block_nodes = []
for node in gm.graph.nodes: for node in gm.graph.nodes:
if 'block_split' in node.name: if "block_split" in node.name:
continue continue
accumulate_fwd_flop += node.fwd_flop accumulate_fwd_flop += node.fwd_flop
accumulate_bwd_flop += node.bwd_flop accumulate_bwd_flop += node.bwd_flop
if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop: if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
with gm.graph.inserting_after(node): with gm.graph.inserting_after(node):
block_node = gm.graph.create_node('call_function', block_split) block_node = gm.graph.create_node("call_function", block_split)
setattr(block_node, 'fwd_flop', accumulate_fwd_flop) setattr(block_node, "fwd_flop", accumulate_fwd_flop)
setattr(block_node, 'bwd_flop', accumulate_bwd_flop) setattr(block_node, "bwd_flop", accumulate_bwd_flop)
accumulate_fwd_flop = 0 accumulate_fwd_flop = 0
accumulate_bwd_flop = 0 accumulate_bwd_flop = 0
block_nodes.append(block_node) block_nodes.append(block_node)
...@@ -47,7 +45,7 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01): ...@@ -47,7 +45,7 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
def remove_blocks(gm: torch.fx.GraphModule):
    """Erase every ``block_split`` marker node from ``gm``'s graph.

    A node counts as a marker when it is a ``call_function`` node whose
    target is ``block_split``. The graph is modified in place; nothing is
    returned and the module is not recompiled here.

    Args:
        gm: graph module whose graph is stripped of block markers.
    """
    # Snapshot the node sequence first so erasing nodes does not interact
    # with the traversal of the live graph.
    for node in list(gm.graph.nodes):
        if (node.op, node.target) == ("call_function", block_split):
            gm.graph.erase_node(node)
...@@ -55,8 +53,8 @@ def get_compute_costs(node_list): ...@@ -55,8 +53,8 @@ def get_compute_costs(node_list):
num_nodes = len(node_list) num_nodes = len(node_list)
all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64) all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)
for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0): for start in tqdm.tqdm(range(num_nodes), desc="start pos", position=0):
for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False): for end in tqdm.tqdm(range(start, num_nodes), desc="end pos", position=1, leave=False):
selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)] selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
all_compute_cost[start, end] = sum(selected_flops) all_compute_cost[start, end] = sum(selected_flops)
...@@ -78,12 +76,14 @@ def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_cost ...@@ -78,12 +76,14 @@ def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_cost
# record start node index for next stage in this partition # record start node index for next stage in this partition
f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32) f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
f[0, num_nodes] = 0 f[0, num_nodes] = 0
for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks for s in tqdm.tqdm(
for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False): range(1, num_stages + 1), desc="stage", position=2, leave=False
for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False): ): # pylint: disable=too-many-nested-blocks
for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc="start node", position=3, leave=False):
for k in tqdm.tqdm(range(num_nodes, i, -1), desc="mid node", position=4, leave=False):
stage_cost = compute_costs[i, k - 1] stage_cost = compute_costs[i, k - 1]
new_cost = f[s - 1, k] + stage_cost new_cost = f[s - 1, k] + stage_cost
if (stage_cost <= max_compute_cost and new_cost < f[s, i]): if stage_cost <= max_compute_cost and new_cost < f[s, i]:
f[s, i] = new_cost f[s, i] = new_cost
f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost) f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
f_argmin[s, i] = k f_argmin[s, i] = k
...@@ -113,7 +113,7 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche ...@@ -113,7 +113,7 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
best_cost = np.inf best_cost = np.inf
best_solution = None best_solution = None
last_max_compute_cost = 0.0 last_max_compute_cost = 0.0
gap = 1e6 # temporary magic number, unit: flops gap = 1e6 # temporary magic number, unit: flops
for max_compute_cost in tqdm.tqdm(max_compute_costs): for max_compute_cost in tqdm.tqdm(max_compute_costs):
# Pruning to reduce search space. # Pruning to reduce search space.
...@@ -122,8 +122,9 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche ...@@ -122,8 +122,9 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
if max_compute_cost - last_max_compute_cost < gap: if max_compute_cost - last_max_compute_cost < gap:
continue continue
cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs, cost, solution = do_dp_split_gpipe_impl(
max_compute_cost) len(node_list), num_stages, num_microbatches, compute_costs, max_compute_cost
)
if cost < best_cost: if cost < best_cost:
best_cost = cost best_cost = cost
...@@ -137,15 +138,15 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche ...@@ -137,15 +138,15 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
# split_mode: # split_mode:
# 'node': fx_node # 'node': fx_node
# 'block': many fx_nodes construct a block # 'block': many fx_nodes construct a block
def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01): def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode="block", block_limit=0.01):
assert mode in ['node', 'block'] assert mode in ["node", "block"]
# nodes or blocks will be used in partition. # nodes or blocks will be used in partition.
node_list = [] node_list = []
if mode == 'node': if mode == "node":
for node in gm.graph.nodes: for node in gm.graph.nodes:
node_list.append(node) node_list.append(node)
elif mode == 'block': elif mode == "block":
node_list = construct_blocks(gm, limit=block_limit) node_list = construct_blocks(gm, limit=block_limit)
else: else:
pass pass
...@@ -154,16 +155,16 @@ def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches ...@@ -154,16 +155,16 @@ def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches
best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches) best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)
for (_, next_start_node) in best_solution: for _, next_start_node in best_solution:
if pp_size <= 1: if pp_size <= 1:
break break
node = node_list[next_start_node] node = node_list[next_start_node]
with gm.graph.inserting_before(node): with gm.graph.inserting_before(node):
split_node = gm.graph.create_node('call_function', pipe_split) split_node = gm.graph.create_node("call_function", pipe_split)
pp_size -= 1 pp_size -= 1
# remove block node if possible # remove block node if possible
if mode == 'block': if mode == "block":
remove_blocks(gm) remove_blocks(gm)
gm.recompile() gm.recompile()
...@@ -178,7 +179,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): ...@@ -178,7 +179,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
# To use avgcompute_split_pass, we need run meta_info_prop interpreter first. # To use avgcompute_split_pass, we need run meta_info_prop interpreter first.
# If nodes don't have meta info, this pass will fall back to normal balanced split pass. # If nodes don't have meta info, this pass will fall back to normal balanced split pass.
check_node = list(mod_graph.nodes)[0] check_node = list(mod_graph.nodes)[0]
if 'tensor_meta' not in check_node.meta: if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size) return balanced_split_pass(gm, pp_size)
total_fwd_flop = 0 total_fwd_flop = 0
...@@ -190,7 +191,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): ...@@ -190,7 +191,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes: for node in mod_graph.nodes:
if pp_size <= 1: if pp_size <= 1:
break break
if 'pipe_split' in node.name: if "pipe_split" in node.name:
continue continue
accumulate_fwd_flop += node.fwd_flop accumulate_fwd_flop += node.fwd_flop
if accumulate_fwd_flop >= partition_flop: if accumulate_fwd_flop >= partition_flop:
...@@ -199,7 +200,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): ...@@ -199,7 +200,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1 pp_size -= 1
partition_flop = total_fwd_flop // pp_size partition_flop = total_fwd_flop // pp_size
with mod_graph.inserting_after(node): with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split) split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile() gm.recompile()
return gm return gm
...@@ -218,12 +219,12 @@ def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int): ...@@ -218,12 +219,12 @@ def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
if accumulate_num_node >= avg_num_node: if accumulate_num_node >= avg_num_node:
accumulate_num_node = 0 accumulate_num_node = 0
pp_size -= 1 pp_size -= 1
if node.next.op == 'output': if node.next.op == "output":
with mod_graph.inserting_before(node): with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split) split_node = mod_graph.create_node("call_function", pipe_split)
else: else:
with mod_graph.inserting_after(node): with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split) split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile() gm.recompile()
return gm return gm
...@@ -250,18 +251,18 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): ...@@ -250,18 +251,18 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1 pp_size -= 1
# If the next node is output node, we will insert split annotation before # If the next node is output node, we will insert split annotation before
# node to make sure there is at least one node in last partition. # node to make sure there is at least one node in last partition.
if node.next.op == 'output': if node.next.op == "output":
with mod_graph.inserting_before(node): with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split) split_node = mod_graph.create_node("call_function", pipe_split)
else: else:
with mod_graph.inserting_after(node): with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split) split_node = mod_graph.create_node("call_function", pipe_split)
if pp_size > 1: if pp_size > 1:
node_counter = 0 node_counter = 0
for node in mod_graph.nodes: for node in mod_graph.nodes:
if pp_size <= 1: if pp_size <= 1:
break break
if node.op == 'placeholder': if node.op == "placeholder":
continue continue
elif node_counter == 0: elif node_counter == 0:
node_counter += 1 node_counter += 1
...@@ -269,7 +270,7 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): ...@@ -269,7 +270,7 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1 pp_size -= 1
node_counter = 0 node_counter = 0
with mod_graph.inserting_before(node): with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split) split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile() gm.recompile()
return gm return gm
...@@ -283,7 +284,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): ...@@ -283,7 +284,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
# To use balanced_split_pass_v2, we need run meta_info_prop interpreter first. # To use balanced_split_pass_v2, we need run meta_info_prop interpreter first.
# If nodes don't have meta info, this pass will fall back to normal balanced split pass. # If nodes don't have meta info, this pass will fall back to normal balanced split pass.
check_node = list(mod_graph.nodes)[0] check_node = list(mod_graph.nodes)[0]
if 'tensor_meta' not in check_node.meta: if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size) return balanced_split_pass(gm, pp_size)
total_element_size = 0 total_element_size = 0
...@@ -295,7 +296,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): ...@@ -295,7 +296,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes: for node in mod_graph.nodes:
if pp_size <= 1: if pp_size <= 1:
break break
if 'pipe_split' in node.name: if "pipe_split" in node.name:
continue continue
accumulate_node_size += node.node_size accumulate_node_size += node.node_size
if accumulate_node_size >= partition_size: if accumulate_node_size >= partition_size:
...@@ -304,7 +305,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): ...@@ -304,7 +305,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1 pp_size -= 1
partition_size = total_element_size // pp_size partition_size = total_element_size // pp_size
with mod_graph.inserting_after(node): with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split) split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile() gm.recompile()
return gm return gm
...@@ -333,7 +334,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int): ...@@ -333,7 +334,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
accumulate_layer_amount = 0 accumulate_layer_amount = 0
pp_size -= 1 pp_size -= 1
with mod_graph.inserting_after(node): with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split) split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile() gm.recompile()
return gm return gm
...@@ -346,7 +347,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output ...@@ -346,7 +347,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
def split_callback(n: torch.fx.Node): def split_callback(n: torch.fx.Node):
nonlocal part_idx nonlocal part_idx
if (n.op, n.target) == ('call_function', pipe_split): if (n.op, n.target) == ("call_function", pipe_split):
part_idx += 1 part_idx += 1
return part_idx return part_idx
...@@ -355,7 +356,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output ...@@ -355,7 +356,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
for name, submodule in split_mod.named_modules(): for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule): if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes: for node in submodule.graph.nodes:
if (node.op, node.target) == ('call_function', pipe_split): if (node.op, node.target) == ("call_function", pipe_split):
submodule.graph.erase_node(node) submodule.graph.erase_node(node)
submodule.recompile() submodule.recompile()
split_submodules.append(submodule) split_submodules.append(submodule)
......
from dataclasses import asdict from dataclasses import asdict
from typing import Any, Dict, List, NamedTuple, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch
import torch.fx import torch.fx
...@@ -85,10 +85,10 @@ class ConcreteInfoProp(torch.fx.Interpreter): ...@@ -85,10 +85,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
self._is_proped = True self._is_proped = True
result, meta_info = super().run_node(n) result, meta_info = super().run_node(n)
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future # TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0)) setattr(n, "node_size", n.meta.get("fwd_mem_tmp", 0) + n.meta.get("fwd_mem_out", 0))
n.meta['type'] = type(result) n.meta["type"] = type(result)
# retain the autograd graph # retain the autograd graph
for param in self.module.parameters(): for param in self.module.parameters():
...@@ -98,7 +98,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): ...@@ -98,7 +98,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
# Main Node running APIs # Main Node running APIs
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``placeholder`` node. Note that this is stateful: Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over ``Interpreter`` maintains an internal iterator over
...@@ -119,7 +119,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): ...@@ -119,7 +119,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return super().placeholder(target, args, kwargs), GraphInfo() return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``get_attr`` node. Will retrieve an attribute Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``. value from the ``Module`` hierarchy of ``self.module``.
...@@ -138,7 +138,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): ...@@ -138,7 +138,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return super().get_attr(target, args, kwargs), GraphInfo() return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``call_function`` node with meta tensor and return the result and its meta profile. Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
...@@ -157,7 +157,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): ...@@ -157,7 +157,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_function(target, self.device)(*args, **kwargs) return profile_function(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``call_method`` node with meta tensor and return the result and its meta profile. Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
...@@ -175,7 +175,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): ...@@ -175,7 +175,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_method(target, self.device)(*args, **kwargs) return profile_method(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``call_module`` node with meta tensor and return the result and its meta profile. Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
...@@ -197,7 +197,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): ...@@ -197,7 +197,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_module(submod, self.device)(*args, **kwargs) return profile_module(submod, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute an ``output`` node. This really just retrieves Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it. the value referenced by the ``output`` node and returns it.
...@@ -228,7 +228,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): ...@@ -228,7 +228,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
""" """
return self.run(*args) return self.run(*args)
def summary(self, unit: str = 'MB') -> str: def summary(self, unit: str = "MB") -> str:
""" """
Summarizes the memory and FLOPs statistics of the `GraphModule` in Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module tabular format. Note that this API requires the ``tabulate`` module
...@@ -238,9 +238,11 @@ class ConcreteInfoProp(torch.fx.Interpreter): ...@@ -238,9 +238,11 @@ class ConcreteInfoProp(torch.fx.Interpreter):
try: try:
from tabulate import tabulate from tabulate import tabulate
except ImportError: except ImportError:
print("`summary` relies on the library `tabulate`, " print(
"which could not be found on this machine. Run `pip " "`summary` relies on the library `tabulate`, "
"install tabulate` to install the library.") "which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`." assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
...@@ -249,10 +251,10 @@ class ConcreteInfoProp(torch.fx.Interpreter): ...@@ -249,10 +251,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
def mem_repr(mem: int) -> str: def mem_repr(mem: int) -> str:
unit_divisor_map = { unit_divisor_map = {
'kb': 1024, "kb": 1024,
'mb': 1024**2, "mb": 1024**2,
'gb': 1024**3, "gb": 1024**3,
'tb': 1024**4, "tb": 1024**4,
} }
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}" return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
...@@ -261,30 +263,32 @@ class ConcreteInfoProp(torch.fx.Interpreter): ...@@ -261,30 +263,32 @@ class ConcreteInfoProp(torch.fx.Interpreter):
for node in self.module.graph.nodes: for node in self.module.graph.nodes:
node: Node node: Node
node_summaries.append([ node_summaries.append(
node.op, [
str(node), node.op,
time_repr(node.meta['fwd_time']), str(node),
time_repr(node.meta['bwd_time']), time_repr(node.meta["fwd_time"]),
node.meta['save_fwd_in'], time_repr(node.meta["bwd_time"]),
mem_repr(node.meta['fwd_mem_out']), node.meta["save_fwd_in"],
mem_repr(node.meta['fwd_mem_tmp']), mem_repr(node.meta["fwd_mem_out"]),
mem_repr(node.meta['bwd_mem_out']), mem_repr(node.meta["fwd_mem_tmp"]),
mem_repr(node.meta['bwd_mem_tmp']), mem_repr(node.meta["bwd_mem_out"]),
]) mem_repr(node.meta["bwd_mem_tmp"]),
]
)
# Use the ``tabulate`` library to create a well-formatted table # Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information # presenting our summary information
headers: List[str] = [ headers: List[str] = [
'Op type', "Op type",
'Op', "Op",
'Forward time', "Forward time",
'Backward time', "Backward time",
'SAVE_FWD_IN', "SAVE_FWD_IN",
'FWD_OUT', "FWD_OUT",
'FWD_TMP', "FWD_TMP",
'BWD_OUT', "BWD_OUT",
'BWD_TMP', "BWD_TMP",
] ]
return tabulate(node_summaries, headers=headers, stralign='right') return tabulate(node_summaries, headers=headers, stralign="right")
import torch
from typing import List
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
import builtins import builtins
import operator import operator
from copy import deepcopy from typing import List
import torch
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
def apply(*args, **kwargs): def apply(*args, **kwargs):
...@@ -24,16 +21,16 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi ...@@ -24,16 +21,16 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi
origin_node_sharding_spec_dict = {} origin_node_sharding_spec_dict = {}
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector strategies_vector = node.strategies_vector
setattr(node, 'best_strategy', strategies_vector[strategy_index]) setattr(node, "best_strategy", strategies_vector[strategy_index])
setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec) setattr(node, "sharding_spec", strategies_vector[strategy_index].output_sharding_spec)
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec
# apply the sharding spec of parameters # apply the sharding spec of parameters
for node in nodes: for node in nodes:
if node.op == 'call_module': if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target) target_module = node.graph.owning_module.get_submodule(node.target)
origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {}) origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {})
setattr(target_module.weight, 'sharding_spec', origin_sharding_spec) setattr(target_module.weight, "sharding_spec", origin_sharding_spec)
target_weight_sharding_spec = node.best_strategy.input_shardings[1] target_weight_sharding_spec = node.best_strategy.input_shardings[1]
target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3)) target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))
apply(target_module.weight, target_weight_sharding_spec) apply(target_module.weight, target_weight_sharding_spec)
...@@ -51,10 +48,10 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi ...@@ -51,10 +48,10 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi
# add above dicts into graph # add above dicts into graph
for node in nodes: for node in nodes:
if node.op != 'placeholder': if node.op != "placeholder":
with mod_graph.inserting_before(node): with mod_graph.inserting_before(node):
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
break break
return sharding_spec_convert_dict, origin_node_sharding_spec_dict return sharding_spec_convert_dict, origin_node_sharding_spec_dict
...@@ -70,13 +67,13 @@ def shape_consistency_pass(gm: torch.fx.GraphModule): ...@@ -70,13 +67,13 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
node_to_index_dict = {} node_to_index_dict = {}
index = 0 index = 0
for node in nodes: for node in nodes:
if node.target == 'sharding_spec_convert_dict': if node.target == "sharding_spec_convert_dict":
input_dict_node = node input_dict_node = node
continue continue
if node.target == 'origin_node_sharding_spec_dict': if node.target == "origin_node_sharding_spec_dict":
origin_dict_node = node origin_dict_node = node
continue continue
if not hasattr(node, 'best_strategy'): if not hasattr(node, "best_strategy"):
continue continue
node_to_index_dict[node] = index node_to_index_dict[node] = index
index += 1 index += 1
...@@ -84,28 +81,28 @@ def shape_consistency_pass(gm: torch.fx.GraphModule): ...@@ -84,28 +81,28 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
# add shape consistency apply function into graph # add shape consistency apply function into graph
for node in nodes: for node in nodes:
if not hasattr(node, 'best_strategy'): if not hasattr(node, "best_strategy"):
continue continue
with mod_graph.inserting_after(node): with mod_graph.inserting_after(node):
origin_spec_node = mod_graph.create_node('call_function', origin_spec_node = mod_graph.create_node(
operator.getitem, "call_function", operator.getitem, args=(origin_dict_node, node_to_index_dict[node])
args=(origin_dict_node, node_to_index_dict[node])) )
with mod_graph.inserting_after(origin_spec_node): with mod_graph.inserting_after(origin_spec_node):
set_sharding_spec_node = mod_graph.create_node('call_function', set_sharding_spec_node = mod_graph.create_node(
builtins.setattr, "call_function", builtins.setattr, args=(node, "sharding_spec", origin_spec_node)
args=(node, 'sharding_spec', origin_spec_node)) )
for user_node in node.strategies_vector.successor_nodes: for user_node in node.strategies_vector.successor_nodes:
node_index = user_node.strategies_vector.predecessor_nodes.index(node) node_index = user_node.strategies_vector.predecessor_nodes.index(node)
with mod_graph.inserting_before(user_node): with mod_graph.inserting_before(user_node):
input_specs_node = mod_graph.create_node('call_function', input_specs_node = mod_graph.create_node(
operator.getitem, "call_function", operator.getitem, args=(input_dict_node, node_to_index_dict[node])
args=(input_dict_node, node_to_index_dict[node])) )
with mod_graph.inserting_before(user_node): with mod_graph.inserting_before(user_node):
sharding_spec_node = mod_graph.create_node('call_function', sharding_spec_node = mod_graph.create_node(
operator.getitem, "call_function", operator.getitem, args=(input_specs_node, node_index)
args=(input_specs_node, node_index)) )
with mod_graph.inserting_before(user_node): with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node)) shape_consistency_node = mod_graph.create_node("call_function", apply, args=(node, sharding_spec_node))
return gm return gm
...@@ -109,13 +109,13 @@ class MetaInfoProp(torch.fx.Interpreter): ...@@ -109,13 +109,13 @@ class MetaInfoProp(torch.fx.Interpreter):
return TensorMetadata(None, None, False, None, 0, False) return TensorMetadata(None, None, False, None, 0, False)
tensor_meta = tree_map(extract_tensor_meta, result) tensor_meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = tensor_meta n.meta["tensor_meta"] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future # TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0))) setattr(n, "node_size", activation_size(n.meta.get("fwd_out", 0)) + activation_size(n.meta.get("fwd_tmp", 0)))
setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0)) setattr(n, "fwd_flop", n.meta.get("fwd_flop", 0))
setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0)) setattr(n, "bwd_flop", n.meta.get("bwd_flop", 0))
n.meta['type'] = type(result) n.meta["type"] = type(result)
# retain the autograd graph # retain the autograd graph
for param in self.module.parameters(): for param in self.module.parameters():
...@@ -125,7 +125,7 @@ class MetaInfoProp(torch.fx.Interpreter): ...@@ -125,7 +125,7 @@ class MetaInfoProp(torch.fx.Interpreter):
# Main Node running APIs # Main Node running APIs
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``placeholder`` node. Note that this is stateful: Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over ``Interpreter`` maintains an internal iterator over
...@@ -146,7 +146,7 @@ class MetaInfoProp(torch.fx.Interpreter): ...@@ -146,7 +146,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return super().placeholder(target, args, kwargs), GraphInfo() return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``get_attr`` node. Will retrieve an attribute Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``. value from the ``Module`` hierarchy of ``self.module``.
...@@ -165,7 +165,7 @@ class MetaInfoProp(torch.fx.Interpreter): ...@@ -165,7 +165,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return super().get_attr(target, args, kwargs), GraphInfo() return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``call_function`` node with meta tensor and return the result and its meta profile. Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
...@@ -184,7 +184,7 @@ class MetaInfoProp(torch.fx.Interpreter): ...@@ -184,7 +184,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return profile_function(target)(*args, **kwargs) return profile_function(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``call_method`` node with meta tensor and return the result and its meta profile. Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
...@@ -202,7 +202,7 @@ class MetaInfoProp(torch.fx.Interpreter): ...@@ -202,7 +202,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return profile_method(target)(*args, **kwargs) return profile_method(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``call_module`` node with meta tensor and return the result and its meta profile. Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
...@@ -224,7 +224,7 @@ class MetaInfoProp(torch.fx.Interpreter): ...@@ -224,7 +224,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return profile_module(submod)(*args, **kwargs) return profile_module(submod)(*args, **kwargs)
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute an ``output`` node. This really just retrieves Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it. the value referenced by the ``output`` node and returns it.
...@@ -240,7 +240,7 @@ class MetaInfoProp(torch.fx.Interpreter): ...@@ -240,7 +240,7 @@ class MetaInfoProp(torch.fx.Interpreter):
result (Any): The argument value that was retrieved result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
""" """
if hasattr(args[0], '_tensor'): if hasattr(args[0], "_tensor"):
return args[0], GraphInfo(fwd_in=[args[0]._tensor]) return args[0], GraphInfo(fwd_in=[args[0]._tensor])
return args[0], GraphInfo(save_fwd_in=True) return args[0], GraphInfo(save_fwd_in=True)
...@@ -257,7 +257,7 @@ class MetaInfoProp(torch.fx.Interpreter): ...@@ -257,7 +257,7 @@ class MetaInfoProp(torch.fx.Interpreter):
""" """
return super().run(*args) return super().run(*args)
def summary(self, unit: str = 'MB') -> str: def summary(self, unit: str = "MB") -> str:
""" """
Summarizes the memory and FLOPs statistics of the `GraphModule` in Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module tabular format. Note that this API requires the ``tabulate`` module
...@@ -267,9 +267,11 @@ class MetaInfoProp(torch.fx.Interpreter): ...@@ -267,9 +267,11 @@ class MetaInfoProp(torch.fx.Interpreter):
try: try:
from tabulate import tabulate from tabulate import tabulate
except ImportError: except ImportError:
print("`summary` relies on the library `tabulate`, " print(
"which could not be found on this machine. Run `pip " "`summary` relies on the library `tabulate`, "
"install tabulate` to install the library.") "which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`." assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
...@@ -278,10 +280,10 @@ class MetaInfoProp(torch.fx.Interpreter): ...@@ -278,10 +280,10 @@ class MetaInfoProp(torch.fx.Interpreter):
def mem_repr(mem: int) -> str: def mem_repr(mem: int) -> str:
unit_divisor_map = { unit_divisor_map = {
'kb': 1024, "kb": 1024,
'mb': 1024**2, "mb": 1024**2,
'gb': 1024**3, "gb": 1024**3,
'tb': 1024**4, "tb": 1024**4,
} }
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}" return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
...@@ -292,35 +294,37 @@ class MetaInfoProp(torch.fx.Interpreter): ...@@ -292,35 +294,37 @@ class MetaInfoProp(torch.fx.Interpreter):
for node in self.module.graph.nodes: for node in self.module.graph.nodes:
node: Node node: Node
accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node) accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node)
node_summaries.append([ node_summaries.append(
node.op, [
str(node), node.op,
flops_repr(node.meta['fwd_flop']), str(node),
flops_repr(node.meta['bwd_flop']), flops_repr(node.meta["fwd_flop"]),
mem_repr(accumulate_size), flops_repr(node.meta["bwd_flop"]),
mem_repr(calculate_fwd_in(node)), mem_repr(accumulate_size),
mem_repr(calculate_fwd_out(node)), mem_repr(calculate_fwd_in(node)),
mem_repr(calculate_fwd_tmp(node)), mem_repr(calculate_fwd_out(node)),
mem_repr(node.meta['bwd_mem_out']), mem_repr(calculate_fwd_tmp(node)),
mem_repr(node.meta['bwd_mem_tmp']), mem_repr(node.meta["bwd_mem_out"]),
]) mem_repr(node.meta["bwd_mem_tmp"]),
]
)
# Use the ``tabulate`` library to create a well-formatted table # Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information # presenting our summary information
headers: List[str] = [ headers: List[str] = [
'Op type', "Op type",
'Op', "Op",
'Forward FLOPs', "Forward FLOPs",
'Backward FLOPs', "Backward FLOPs",
'Accumulated Memory', "Accumulated Memory",
'FWD_IN', "FWD_IN",
'FWD_OUT', "FWD_OUT",
'FWD_TMP', "FWD_TMP",
'BWD_OUT', "BWD_OUT",
'BWD_TMP', "BWD_TMP",
] ]
return tabulate(node_summaries, headers=headers, stralign='right') return tabulate(node_summaries, headers=headers, stralign="right")
def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None: def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None:
...@@ -344,15 +348,16 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: ...@@ -344,15 +348,16 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit:
Returns: Returns:
torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo. torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
""" """
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
interp = MetaInfoProp(gm.to(device)) interp = MetaInfoProp(gm.to(device))
if is_compatible_with_meta(): if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor from colossalai.fx.profiler import MetaTensor
args = tree_map(lambda x: MetaTensor(x, fake_device=device), args) args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs) kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
interp.propagate(*args, **kwargs) interp.propagate(*args, **kwargs)
if verbose: if verbose:
interp.summary(unit) interp.summary(unit)
gm.to('cpu') gm.to("cpu")
del interp del interp
return gm return gm
...@@ -5,7 +5,6 @@ import torch ...@@ -5,7 +5,6 @@ import torch
from packaging import version from packaging import version
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split
from colossalai.fx.passes.meta_info_prop import TensorMetadata from colossalai.fx.passes.meta_info_prop import TensorMetadata
...@@ -13,9 +12,9 @@ from colossalai.fx.passes.split_module import Partition ...@@ -13,9 +12,9 @@ from colossalai.fx.passes.split_module import Partition
def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]): def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
''' """
This pass is only used to do the gpt2 performance test, it may move into adding_split_node_pass.py, and will be deprecated in future. This pass is only used to do the gpt2 performance test, it may move into adding_split_node_pass.py, and will be deprecated in future.
''' """
mod_graph = gm.graph mod_graph = gm.graph
valid_children_size = 0 valid_children_size = 0
valid_children = [] valid_children = []
...@@ -39,40 +38,40 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, parti ...@@ -39,40 +38,40 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, parti
part_index += 1 part_index += 1
pp_size -= 1 pp_size -= 1
with mod_graph.inserting_after(node): with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split) split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile() gm.recompile()
return gm return gm
def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule): def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
''' """
This pass will be used in gpt2 test, only a part of changes may be added into This pass will be used in gpt2 test, only a part of changes may be added into
split_with_split_nodes_pass, and it will be deprecated in future. split_with_split_nodes_pass, and it will be deprecated in future.
''' """
part_idx = 0 part_idx = 0
def eliminate_unused_placeholders(gm): def eliminate_unused_placeholders(gm):
for node in gm.graph.nodes: for node in gm.graph.nodes:
if node.op == 'placeholder': if node.op == "placeholder":
if not len(node.users): if not len(node.users):
gm.graph.erase_node(node) gm.graph.erase_node(node)
gm.recompile() gm.recompile()
return gm return gm
def refill_outputs_and_placeholders(gm, next_partition_placeholders): def refill_outputs_and_placeholders(gm, next_partition_placeholders):
''' """
This method is used to eliminate the outputs in previous partition which is unused in next partition. This method is used to eliminate the outputs in previous partition which is unused in next partition.
In split module pass, it treats partitions as a DAG, but we need treat them as a single direction linked list in pipeline parallel. In split module pass, it treats partitions as a DAG, but we need treat them as a single direction linked list in pipeline parallel.
The difference is if a output from partition 0 is an input argument of partition 3, the DAG will not transfer it The difference is if a output from partition 0 is an input argument of partition 3, the DAG will not transfer it
to partition 1 and partition 2. However, in single direction linked list, we need to do so. to partition 1 and partition 2. However, in single direction linked list, we need to do so.
''' """
output_type = None output_type = None
output_args = [] output_args = []
non_output_list = [] non_output_list = []
new_placeholder_list = [] new_placeholder_list = []
for node in gm.graph.nodes: for node in gm.graph.nodes:
if node.op == 'output': if node.op == "output":
if isinstance(node.args[0], (tuple, list)): if isinstance(node.args[0], (tuple, list)):
output_type = node.args[0].__class__ output_type = node.args[0].__class__
output_args.extend([n.name for n in node.args[0]]) output_args.extend([n.name for n in node.args[0]])
...@@ -114,7 +113,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule) ...@@ -114,7 +113,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
continue continue
for node in gm.graph.nodes: for node in gm.graph.nodes:
if node.op == 'placeholder': if node.op == "placeholder":
new_placeholder_list.append(node.name) new_placeholder_list.append(node.name)
if output_type is not None: if output_type is not None:
gm.graph.output(output_type(output_args)) gm.graph.output(output_type(output_args))
...@@ -125,7 +124,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule) ...@@ -125,7 +124,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
def split_callback(n: torch.fx.Node): def split_callback(n: torch.fx.Node):
nonlocal part_idx nonlocal part_idx
if (n.op, n.target) == ('call_function', pipe_split): if (n.op, n.target) == ("call_function", pipe_split):
part_idx += 1 part_idx += 1
return part_idx return part_idx
...@@ -134,7 +133,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule) ...@@ -134,7 +133,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
for name, submodule in split_mod.named_modules(): for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule): if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes: for node in submodule.graph.nodes:
if (node.op, node.target) == ('call_function', pipe_split): if (node.op, node.target) == ("call_function", pipe_split):
submodule.graph.erase_node(node) submodule.graph.erase_node(node)
submodule.recompile() submodule.recompile()
split_submodules.append(submodule) split_submodules.append(submodule)
...@@ -200,13 +199,12 @@ def split_module_for_gpt2_test( ...@@ -200,13 +199,12 @@ def split_module_for_gpt2_test(
_gen_all_ancestors_set(node) _gen_all_ancestors_set(node)
for n in list(all_ancestors): for n in list(all_ancestors):
if n.op != 'placeholder' and n._fx_partition > partition_name: if n.op != "placeholder" and n._fx_partition > partition_name:
n._fx_partition = partition_name n._fx_partition = partition_name
def record_cross_partition_use(def_node: torch.fx.node.Node, def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
use_node: Optional[torch.fx.node.Node]): # noqa: B950 def_partition_name = getattr(def_node, "_fx_partition", None)
def_partition_name = getattr(def_node, '_fx_partition', None) use_partition_name = getattr(use_node, "_fx_partition", None)
use_partition_name = getattr(use_node, '_fx_partition', None)
if def_partition_name != use_partition_name: if def_partition_name != use_partition_name:
# if 'tensor_meta' in def_node.meta: # if 'tensor_meta' in def_node.meta:
# if not _node_with_all_tensor_element(def_node.meta['tensor_meta']): # if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
...@@ -237,7 +235,7 @@ def split_module_for_gpt2_test( ...@@ -237,7 +235,7 @@ def split_module_for_gpt2_test(
if node.op in ["placeholder"]: if node.op in ["placeholder"]:
continue continue
if node.op == 'output': if node.op == "output":
# partition_name = str(split_callback(node)) # partition_name = str(split_callback(node))
# def _set_output_args_partition(n, partition_name): # def _set_output_args_partition(n, partition_name):
# n._fx_partition = partition_name # n._fx_partition = partition_name
...@@ -252,12 +250,12 @@ def split_module_for_gpt2_test( ...@@ -252,12 +250,12 @@ def split_module_for_gpt2_test(
partitions[partition_name] = partition = Partition(partition_name) partitions[partition_name] = partition = Partition(partition_name)
partition.node_names.append(node.name) partition.node_names.append(node.name)
origin_partition_name = getattr(node, '_fx_partition', None) origin_partition_name = getattr(node, "_fx_partition", None)
if origin_partition_name is None: if origin_partition_name is None:
node._fx_partition = partition_name node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies # find partitions with no dependencies
root_partitions: List[str] = [] root_partitions: List[str] = []
...@@ -287,7 +285,7 @@ def split_module_for_gpt2_test( ...@@ -287,7 +285,7 @@ def split_module_for_gpt2_test(
# Transform nodes and collect targets for partition's submodule # Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes: for node in m.graph.nodes:
if hasattr(node, '_fx_partition'): if hasattr(node, "_fx_partition"):
partition = partitions[node._fx_partition] partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule # swap out old graph nodes in kw/args with references to new nodes in this submodule
...@@ -295,26 +293,24 @@ def split_module_for_gpt2_test( ...@@ -295,26 +293,24 @@ def split_module_for_gpt2_test(
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n]) gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
if node.op not in ['call_module', 'get_attr']: if node.op not in ["call_module", "get_attr"]:
target = node.target target = node.target
else: else:
target_atoms = node.target.split('.') target_atoms = node.target.split(".")
target_attr = m target_attr = m
for atom in target_atoms: for atom in target_atoms:
if not hasattr(target_attr, atom): if not hasattr(target_attr, atom):
raise RuntimeError(f'Operator target {node.target} not found!') raise RuntimeError(f"Operator target {node.target} not found!")
target_attr = getattr(target_attr, atom) target_attr = getattr(target_attr, atom)
# target = target_atoms[-1] # target = target_atoms[-1]
target = '_'.join(target_atoms) target = "_".join(target_atoms)
partition.targets[target] = target_attr partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple) assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict) assert isinstance(gathered_kwargs, dict)
new_node = partition.graph.create_node(op=node.op, new_node = partition.graph.create_node(
target=target, op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs, name=node.name
args=gathered_args, )
kwargs=gathered_kwargs,
name=node.name)
new_node.meta = node.meta.copy() new_node.meta = node.meta.copy()
partition.environment[node] = new_node partition.environment[node] = new_node
...@@ -323,14 +319,14 @@ def split_module_for_gpt2_test( ...@@ -323,14 +319,14 @@ def split_module_for_gpt2_test(
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes: for node in m.graph.nodes:
if node.op == 'placeholder': if node.op == "placeholder":
if version.parse(torch.__version__) < version.parse('1.11.0'): if version.parse(torch.__version__) < version.parse("1.11.0"):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type) base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
else: else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, base_mod_env[node.name] = base_mod_graph.placeholder(
type_expr=node.type, node.name, type_expr=node.type, default_value=default_value
default_value=default_value) )
base_mod_env[node.name].meta = node.meta.copy() base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again: # Do some things iterating over the partitions in topological order again:
...@@ -344,13 +340,14 @@ def split_module_for_gpt2_test( ...@@ -344,13 +340,14 @@ def split_module_for_gpt2_test(
# Set correct output values # Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs) output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals) partition.graph.output(output_vals)
# Construct GraphModule for this partition # Construct GraphModule for this partition
submod_name = f'submod_{partition_name}' submod_name = f"submod_{partition_name}"
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
partition.graph) # noqa: B950 partition.targets, partition.graph
) # noqa: B950
# Emit call in base graph to this submodule # Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs)) output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
...@@ -358,14 +355,14 @@ def split_module_for_gpt2_test( ...@@ -358,14 +355,14 @@ def split_module_for_gpt2_test(
# Unpack multiple return values from submodule # Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val) output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs): for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else: else:
if not partition.outputs: if not partition.outputs:
continue continue
base_mod_env[list(partition.outputs)[0]] = output_val base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes: for node in m.graph.nodes:
if node.op == 'output': if node.op == "output":
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
...@@ -9,8 +9,19 @@ from colossalai.legacy.tensor.distspec import ShardSpec ...@@ -9,8 +9,19 @@ from colossalai.legacy.tensor.distspec import ShardSpec
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [ ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, torch.add,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout operator.add,
torch.abs,
torch.cos,
torch.exp,
torch.mul,
operator.mul,
operator.floordiv,
operator.truediv,
operator.neg,
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
] ]
...@@ -72,7 +83,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc ...@@ -72,7 +83,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
# traverse the graph to look for consecutive linear layers # traverse the graph to look for consecutive linear layers
is_linear_module = False is_linear_module = False
if node.op == 'call_module': if node.op == "call_module":
# look for the linear layer # look for the linear layer
module = node.graph.owning_module.get_submodule(node.target) module = node.graph.owning_module.get_submodule(node.target)
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
...@@ -82,31 +93,31 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc ...@@ -82,31 +93,31 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
# it means the first linear has been found and the current module # it means the first linear has been found and the current module
# is the second linear # is the second linear
# set the current linear module to be row-sharded # set the current linear module to be row-sharded
annotation_record['row'] = module annotation_record["row"] = module
for shard_type, module in annotation_record.items(): for shard_type, module in annotation_record.items():
# add row sharding spec # add row sharding spec
if shard_type == 'row': if shard_type == "row":
dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size]) dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
comp_spec = ComputeSpec(ComputePattern.TP1D) comp_spec = ComputeSpec(ComputePattern.TP1D)
setattr(module.weight, 'pg', process_group) setattr(module.weight, "pg", process_group)
setattr(module.weight, 'dist_spec', dist_spec) setattr(module.weight, "dist_spec", dist_spec)
setattr(module.weight, 'comp_spec', comp_spec) setattr(module.weight, "comp_spec", comp_spec)
elif shard_type == 'col': elif shard_type == "col":
weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size]) weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
weight_comp_spec = ComputeSpec(ComputePattern.TP1D) weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
weight_comp_spec.output_replicate = False weight_comp_spec.output_replicate = False
setattr(module.weight, 'pg', process_group) setattr(module.weight, "pg", process_group)
setattr(module.weight, 'dist_spec', weight_dist_spec) setattr(module.weight, "dist_spec", weight_dist_spec)
setattr(module.weight, 'comp_spec', weight_comp_spec) setattr(module.weight, "comp_spec", weight_comp_spec)
if module.bias is not None: if module.bias is not None:
bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size]) bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
bias_comp_spec = ComputeSpec(ComputePattern.TP1D) bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
bias_comp_spec.output_replicate = False bias_comp_spec.output_replicate = False
setattr(module.bias, 'pg', process_group) setattr(module.bias, "pg", process_group)
setattr(module.bias, 'dist_spec', bias_dist_spec) setattr(module.bias, "dist_spec", bias_dist_spec)
setattr(module.bias, 'comp_spec', bias_comp_spec) setattr(module.bias, "comp_spec", bias_comp_spec)
start_tracking = False start_tracking = False
annotation_record.clear() annotation_record.clear()
else: else:
...@@ -114,16 +125,16 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc ...@@ -114,16 +125,16 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
# it means the current layer is the first linear # it means the current layer is the first linear
# set the linear layer to be col-sharded # set the linear layer to be col-sharded
start_tracking = True start_tracking = True
annotation_record['col'] = module annotation_record["col"] = module
if start_tracking and not is_linear_module: if start_tracking and not is_linear_module:
# check against the white list # check against the white list
# if non-element wise op is found, we reset the tracking # if non-element wise op is found, we reset the tracking
if node.op == 'call_module': if node.op == "call_module":
module = node.graph.owning_module.get_submodule(node.target) module = node.graph.owning_module.get_submodule(node.target)
if module.__class__ not in ELEMENTWISE_MODULE_OP: if module.__class__ not in ELEMENTWISE_MODULE_OP:
start_tracking = False start_tracking = False
elif node.op == 'call_function' or node.op == 'call_method': elif node.op == "call_function" or node.op == "call_method":
if node.target not in ELEMENTWISE_FUNC_OP: if node.target not in ELEMENTWISE_FUNC_OP:
start_tracking = False start_tracking = False
elif len(node.users.keys()) > 1: elif len(node.users.keys()) > 1:
......
...@@ -25,12 +25,14 @@ class Partition: ...@@ -25,12 +25,14 @@ class Partition:
self.targets: Dict[str, Any] = {} self.targets: Dict[str, Any] = {}
def __repr__(self) -> str: def __repr__(self) -> str:
return f"name: {self.name},\n" \ return (
f" nodes: {self.node_names},\n" \ f"name: {self.name},\n"
f" inputs: {self.inputs},\n" \ f" nodes: {self.node_names},\n"
f" outputs: {self.outputs},\n" \ f" inputs: {self.inputs},\n"
f" partitions dependent on: {self.partitions_dependent_on},\n" \ f" outputs: {self.outputs},\n"
f" partitions dependent on: {self.partitions_dependent_on},\n"
f" partition dependents: {self.partition_dependents}" f" partition dependents: {self.partition_dependents}"
)
# Creates subgraphs out of main graph # Creates subgraphs out of main graph
...@@ -117,10 +119,9 @@ def split_module( ...@@ -117,10 +119,9 @@ def split_module(
partitions: Dict[str, Partition] = {} partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, torch.fx.node.Node] = {} orig_nodes: Dict[str, torch.fx.node.Node] = {}
def record_cross_partition_use(def_node: torch.fx.node.Node, def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
use_node: Optional[torch.fx.node.Node]): # noqa: B950 def_partition_name = getattr(def_node, "_fx_partition", None)
def_partition_name = getattr(def_node, '_fx_partition', None) use_partition_name = getattr(use_node, "_fx_partition", None)
use_partition_name = getattr(use_node, '_fx_partition', None)
if def_partition_name != use_partition_name: if def_partition_name != use_partition_name:
if def_partition_name is not None: if def_partition_name is not None:
def_partition = partitions[def_partition_name] def_partition = partitions[def_partition_name]
...@@ -134,7 +135,7 @@ def split_module( ...@@ -134,7 +135,7 @@ def split_module(
if def_partition_name is not None: if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name) use_partition.partitions_dependent_on.setdefault(def_partition_name)
def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None) def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None) use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name: if def_partition_name != use_partition_name:
...@@ -161,7 +162,7 @@ def split_module( ...@@ -161,7 +162,7 @@ def split_module(
if node.op in ["placeholder"]: if node.op in ["placeholder"]:
continue continue
if node.op == 'output': if node.op == "output":
if merge_output: if merge_output:
torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev)) torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
else: else:
...@@ -178,7 +179,7 @@ def split_module( ...@@ -178,7 +179,7 @@ def split_module(
node._fx_partition = partition_name node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies # find partitions with no dependencies
root_partitions: List[str] = [] root_partitions: List[str] = []
...@@ -208,7 +209,7 @@ def split_module( ...@@ -208,7 +209,7 @@ def split_module(
# Transform nodes and collect targets for partition's submodule # Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes: for node in m.graph.nodes:
if hasattr(node, '_fx_partition'): if hasattr(node, "_fx_partition"):
partition = partitions[node._fx_partition] partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule # swap out old graph nodes in kw/args with references to new nodes in this submodule
...@@ -216,25 +217,24 @@ def split_module( ...@@ -216,25 +217,24 @@ def split_module(
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n]) gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
if node.op not in ['call_module', 'get_attr']: if node.op not in ["call_module", "get_attr"]:
target = node.target target = node.target
else: else:
target_atoms = node.target.split('.') target_atoms = node.target.split(".")
target_attr = m target_attr = m
for atom in target_atoms: for atom in target_atoms:
if not hasattr(target_attr, atom): if not hasattr(target_attr, atom):
raise RuntimeError(f'Operator target {node.target} not found!') raise RuntimeError(f"Operator target {node.target} not found!")
target_attr = getattr(target_attr, atom) target_attr = getattr(target_attr, atom)
# target = target_atoms[-1] # target = target_atoms[-1]
target = '_'.join(target_atoms) target = "_".join(target_atoms)
partition.targets[target] = target_attr partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple) assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict) assert isinstance(gathered_kwargs, dict)
new_node = partition.graph.create_node(op=node.op, new_node = partition.graph.create_node(
target=target, op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs
args=gathered_args, )
kwargs=gathered_kwargs)
new_node.meta = node.meta.copy() new_node.meta = node.meta.copy()
partition.environment[node] = new_node partition.environment[node] = new_node
...@@ -243,14 +243,14 @@ def split_module( ...@@ -243,14 +243,14 @@ def split_module(
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes: for node in m.graph.nodes:
if node.op == 'placeholder': if node.op == "placeholder":
if version.parse(torch.__version__) < version.parse('1.11.0'): if version.parse(torch.__version__) < version.parse("1.11.0"):
base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type) base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
else: else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.target, base_mod_env[node.name] = base_mod_graph.placeholder(
type_expr=node.type, node.target, type_expr=node.type, default_value=default_value
default_value=default_value) )
base_mod_env[node.name].meta = node.meta.copy() base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again: # Do some things iterating over the partitions in topological order again:
...@@ -264,13 +264,14 @@ def split_module( ...@@ -264,13 +264,14 @@ def split_module(
# Set correct output values # Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs) output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals) partition.graph.output(output_vals)
# Construct GraphModule for this partition # Construct GraphModule for this partition
submod_name = f'submod_{partition_name}' submod_name = f"submod_{partition_name}"
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
partition.graph) # noqa: B950 partition.targets, partition.graph
) # noqa: B950
# Emit call in base graph to this submodule # Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs)) output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
...@@ -278,15 +279,15 @@ def split_module( ...@@ -278,15 +279,15 @@ def split_module(
# Unpack multiple return values from submodule # Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val) output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs): for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else: else:
if not partition.outputs: if not partition.outputs:
continue continue
base_mod_env[list(partition.outputs)[0]] = output_val base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes: for node in m.graph.nodes:
if node.op == 'output': if node.op == "output":
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
for partition_name in sorted_partitions: for partition_name in sorted_partitions:
partition = partitions[partition_name] partition = partitions[partition_name]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment