Commit 9e768b59 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents 7bc5a8e3 8aed02b9
......@@ -7,8 +7,13 @@ import fabric
from .hostinfo import HostInfo, HostInfoList
def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
send_conn: mp_connection.Connection, env: dict) -> None:
def run_on_host(
hostinfo: HostInfo,
workdir: str,
recv_conn: mp_connection.Connection,
send_conn: mp_connection.Connection,
env: dict,
) -> None:
"""
Use fabric connection to execute command on local or remote hosts.
......@@ -22,14 +27,14 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)
finish = False
env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()])
env_msg = " ".join([f'{k}="{v}"' for k, v in env.items()])
# keep listening until exit
while not finish:
# receive cmd
cmds = recv_conn.recv()
if cmds == 'exit':
if cmds == "exit":
# exit from the loop
finish = True
break
......@@ -46,12 +51,12 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
else:
# execute on the remote machine
fab_conn.run(cmds, hide=False)
send_conn.send('success')
send_conn.send("success")
except Exception as e:
click.echo(
f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}"
)
send_conn.send('failure')
send_conn.send("failure")
# shutdown
send_conn.send("finish")
......@@ -96,8 +101,7 @@ class MultiNodeRunner:
cmd (str): the command to execute
"""
assert hostinfo.hostname in self.master_send_conns, \
f'{hostinfo} is not found in the current connections'
assert hostinfo.hostname in self.master_send_conns, f"{hostinfo} is not found in the current connections"
conn = self.master_send_conns[hostinfo.hostname]
conn.send(cmd)
......@@ -107,14 +111,14 @@ class MultiNodeRunner:
"""
for hostname, conn in self.master_send_conns.items():
conn.send('exit')
conn.send("exit")
def recv_from_all(self) -> dict:
"""
Receive messages from all hosts
Returns:
msg_from_node (dict): a dictionry which contains messages from each node
msg_from_node (dict): a dictionary which contains messages from each node
"""
msg_from_node = dict()
......
......@@ -12,7 +12,7 @@ from .hostinfo import HostInfo, HostInfoList
from .multinode_runner import MultiNodeRunner
# Constants that define our syntax
NODE_SEP = ','
NODE_SEP = ","
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
......@@ -34,12 +34,12 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
exit()
with open(hostfile_path, 'r') as fd:
with open(hostfile_path, "r") as fd:
device_pool = HostInfoList()
for line in fd.readlines():
line = line.strip()
if line == '':
if line == "":
# skip empty lines
continue
......@@ -56,7 +56,7 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
'''Parse an inclusion or exclusion string and filter a hostfile dictionary.
"""Parse an inclusion or exclusion string and filter a hostfile dictionary.
Examples:
include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
......@@ -69,7 +69,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
Returns:
filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
'''
"""
# Ensure include/exclude are mutually exclusive
if include_str and exclude_str:
......@@ -136,16 +136,16 @@ def get_launch_command(
for k, v in arg_dict.items():
if v:
ret.append(f'--{k}={v}')
ret.append(f"--{k}={v}")
else:
ret.append(f'--{k}')
ret.append(f"--{k}")
return ret
if extra_launch_args:
extra_launch_args_dict = dict()
for arg in extra_launch_args.split(','):
if '=' in arg:
k, v = arg.split('=')
for arg in extra_launch_args.split(","):
if "=" in arg:
k, v = arg.split("=")
extra_launch_args_dict[k] = v
else:
extra_launch_args_dict[arg] = None
......@@ -154,19 +154,23 @@ def get_launch_command(
extra_launch_args = dict()
torch_version = version.parse(torch.__version__)
assert torch_version.major == 1
assert torch_version.major >= 1
if torch_version.minor < 9:
if torch_version.major == 1 and torch_version.minor < 9:
# torch distributed launch cmd with torch < 1.9
cmd = [
sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}",
f"--node_rank={node_rank}"
sys.executable,
"-m",
"torch.distributed.launch",
f"--nproc_per_node={nproc_per_node}",
f"--master_addr={master_addr}",
f"--master_port={master_port}",
f"--nnodes={num_nodes}",
f"--node_rank={node_rank}",
]
else:
# extra launch args for torch distributed launcher with torch >= 1.9
default_torchrun_rdzv_args = dict(rdzv_backend="c10d",
rdzv_endpoint=f"{master_addr}:{master_port}",
rdzv_id="colossalai-default-job")
default_torchrun_rdzv_args = dict(master_addr=master_addr, master_port=master_port)
# update rdzv arguments
for key in default_torchrun_rdzv_args.keys():
......@@ -174,19 +178,28 @@ def get_launch_command(
value = extra_launch_args.pop(key)
default_torchrun_rdzv_args[key] = value
if torch_version.minor < 10:
if torch_version.major == 1 and torch_version.minor == 9:
# torch distributed launch cmd with torch == 1.9
cmd = [
sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}",
f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
sys.executable,
"-m",
"torch.distributed.run",
f"--nproc_per_node={nproc_per_node}",
f"--nnodes={num_nodes}",
f"--node_rank={node_rank}",
]
else:
# torch distributed launch cmd with torch > 1.9
cmd = [
"torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
"torchrun",
f"--nproc_per_node={nproc_per_node}",
f"--nnodes={num_nodes}",
f"--node_rank={node_rank}",
]
cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
cmd = ' '.join(cmd)
cmd = " ".join(cmd)
return cmd
......@@ -250,33 +263,39 @@ def launch_multi_processes(args: Config) -> None:
# run on local node if not hosts or hostfile is given
# add local node to host info list
active_device_pool = HostInfoList()
localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port)
localhost_info = HostInfo(hostname="127.0.0.1", port=args.ssh_port)
active_device_pool.append(localhost_info)
# launch distributed processes
runner = MultiNodeRunner()
curr_path = os.path.abspath('.')
curr_path = os.path.abspath(".")
# collect current path env
env = dict()
for k, v in os.environ.items():
# do not support multi-line env var
if v and '\n' not in v:
if v and "\n" not in v:
env[k] = v
# establish remote connection
runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env)
# overwrite master addr when num_nodes > 1 and not specified
if len(active_device_pool) > 1 and args.master_addr == "127.0.0.1":
args.master_addr = active_device_pool.hostinfo_list[0].hostname
# execute distributed launching command
for node_id, hostinfo in enumerate(active_device_pool):
cmd = get_launch_command(master_addr=args.master_addr,
cmd = get_launch_command(
master_addr=args.master_addr,
master_port=args.master_port,
nproc_per_node=args.nproc_per_node,
user_script=args.user_script,
user_args=args.user_args,
node_rank=node_id,
num_nodes=len(active_device_pool),
extra_launch_args=args.extra_launch_args)
extra_launch_args=args.extra_launch_args,
)
runner.send(hostinfo=hostinfo, cmd=cmd)
# start training
......@@ -298,7 +317,7 @@ def launch_multi_processes(args: Config) -> None:
# receive the stop status
msg_from_node = runner.recv_from_all()
# printe node status
# print node status
click.echo("\n====== Stopping All Nodes =====")
for hostname, msg in msg_from_node.items():
click.echo(f"{hostname}: {msg}")
......
from .device_mesh_manager import DeviceMeshManager
from .dist_coordinator import DistCoordinator
from .process_group_manager import ProcessGroupManager
from .process_group_mesh import ProcessGroupMesh
__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager']
__all__ = ["DistCoordinator", "ProcessGroupManager", "DeviceMeshManager", "ProcessGroupMesh"]
......@@ -10,13 +10,14 @@ from colossalai.device.device_mesh import DeviceMesh
@dataclass
class DeviceMeshInfo:
'''
"""
This class is used to store the information used to initialize the device mesh.
Args:
physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7].
mesh_shapes (List[Union[torch.Size, List[int], Tuple[int]]]): The shape of the mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2].
'''
"""
physical_ids: List[int]
mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None
......@@ -24,16 +25,18 @@ class DeviceMeshInfo:
if self.mesh_shape is not None:
world_size = len(self.physical_ids)
mesh_shape_numel = torch.Size(self.mesh_shape).numel()
assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}'
assert (
world_size == mesh_shape_numel
), f"the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}"
def initialize_device_mesh(device_mesh_info: DeviceMeshInfo):
'''
"""
This method is used to initialize the device mesh.
Args:
device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh.
'''
"""
# parse the device mesh info
physical_devices = device_mesh_info.physical_ids
physical_mesh = torch.tensor(physical_devices)
......@@ -73,7 +76,7 @@ class DeviceMeshManager:
self.device_mesh_store[name] = device_mesh
return device_mesh
else:
raise ValueError(f'Device mesh {name} already exists.')
raise ValueError(f"Device mesh {name} already exists.")
def get(self, name: str) -> DeviceMesh:
"""
......@@ -88,7 +91,7 @@ class DeviceMeshManager:
if name in self.device_mesh_store:
return self.device_mesh_store[name]
else:
raise ValueError(f'Device mesh {name} does not exist.')
raise ValueError(f"Device mesh {name} does not exist.")
def destroy(self, name: str) -> None:
"""
......@@ -103,7 +106,7 @@ class DeviceMeshManager:
dist.destroy_process_group(pg)
del self.device_mesh_store[name]
else:
raise ValueError(f'Device mesh {name} does not exist.')
raise ValueError(f"Device mesh {name} does not exist.")
def destroy_all(self):
"""
......
......@@ -20,14 +20,16 @@ class DistCoordinator(metaclass=SingletonMeta):
- master: the process with rank 0
- node master: the process with local rank 0 on the current node
Example:
>>> from colossalai.cluster.dist_coordinator import DistCoordinator
>>> coordinator = DistCoordinator()
>>>
>>> if coordinator.is_master():
>>> do_something()
>>>
>>> coordinator.print_on_master('hello world')
```python
from colossalai.cluster.dist_coordinator import DistCoordinator
coordinator = DistCoordinator()
if coordinator.is_master():
do_something()
coordinator.print_on_master('hello world')
```
Attributes:
rank (int): the rank of the current process
......@@ -36,12 +38,13 @@ class DistCoordinator(metaclass=SingletonMeta):
"""
def __init__(self):
assert dist.is_initialized(
), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
assert (
dist.is_initialized()
), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first."
self._rank = dist.get_rank()
self._world_size = dist.get_world_size()
# this is often passed by launchers such as torchrun
self._local_rank = os.environ.get('LOCAL_RANK', -1)
self._local_rank = os.environ.get("LOCAL_RANK", -1)
@property
def rank(self) -> int:
......@@ -59,7 +62,9 @@ class DistCoordinator(metaclass=SingletonMeta):
"""
Assert that the local rank is set. This is often passed by launchers such as torchrun.
"""
assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'
assert (
self.local_rank >= 0
), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process."
def is_master(self, process_group: ProcessGroup = None) -> bool:
"""
......@@ -128,11 +133,13 @@ class DistCoordinator(metaclass=SingletonMeta):
other processes in the same process group. This is often useful when downloading is required
as we only want to download in one process to prevent file corruption.
Example:
>>> from colossalai.cluster import DistCoordinator
>>> dist_coordinator = DistCoordinator()
>>> with dist_coordinator.priority_execution():
>>> dataset = CIFAR10(root='./data', download=True)
```python
from colossalai.cluster import DistCoordinator
dist_coordinator = DistCoordinator()
with dist_coordinator.priority_execution():
dataset = CIFAR10(root='./data', download=True)
```
Args:
executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
......@@ -171,19 +178,19 @@ class DistCoordinator(metaclass=SingletonMeta):
"""
A function wrapper that only executes the wrapped function on the master process (rank 0).
Example:
>>> from colossalai.cluster import DistCoordinator
>>> dist_coordinator = DistCoordinator()
>>>
>>> @dist_coordinator.on_master_only()
>>> def print_on_master(msg):
>>> print(msg)
```python
from colossalai.cluster import DistCoordinator
dist_coordinator = DistCoordinator()
@dist_coordinator.on_master_only()
def print_on_master(msg):
print(msg)
```
"""
is_master = self.is_master(process_group)
# define an inner functiuon
# define an inner function
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_master:
......
......@@ -19,7 +19,7 @@ class ProcessGroupManager:
def __init__(self):
self.pg_store = dict()
def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
def create_process_group(self, name: str, ranks: List[int], backend: str = "nccl") -> ProcessGroup:
"""
Get a process group by name. If the process group does not exist, it will be created.
......@@ -36,7 +36,7 @@ class ProcessGroupManager:
self.pg_store[name] = pg
return pg
else:
raise ValueError(f'Process group {name} already exists.')
raise ValueError(f"Process group {name} already exists.")
def get(self, name: str) -> ProcessGroup:
"""
......@@ -51,7 +51,7 @@ class ProcessGroupManager:
if name in self.pg_store:
return self.pg_store[name]
else:
raise ValueError(f'Process group {name} does not exist.')
raise ValueError(f"Process group {name} does not exist.")
def destroy(self, name: str) -> None:
"""
......@@ -64,7 +64,7 @@ class ProcessGroupManager:
dist.destroy_process_group(self.pg_store[name])
del self.pg_store[name]
else:
raise ValueError(f'Process group {name} does not exist.')
raise ValueError(f"Process group {name} does not exist.")
def destroy_all(self) -> None:
"""
......
import itertools
from functools import reduce
from operator import mul
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup
def prod(nums: List[int]) -> int:
    """Product of a list of numbers.

    Args:
        nums (List[int]): A list of numbers. May be empty.

    Returns:
        int: The product of the numbers; ``1`` (the empty product) when ``nums`` is empty.
    """
    # The explicit initial value makes an empty input return 1 instead of
    # raising ``TypeError: reduce() of empty iterable with no initial value``.
    return reduce(mul, nums, 1)
class ProcessGroupMesh:
    """A helper class to manage the process group mesh.

    It only describes how to organize process groups, and it is decoupled from the parallel method.
    It just initializes process groups and caches them. The parallel method should manage them and
    use them to do the parallel computation.

    We use an ND-tuple to represent the process group mesh, and an ND-coordinate to represent each process.
    For example, ``(0, 1, 0)`` represents the process whose rank is 2 in a 3D process group mesh with size ``(2, 2, 2)``.

    Args:
        *size (int): The size of each dimension of the process group mesh. The product of the size must be equal to the world size.

    Attributes:
        shape (Tuple[int, ...]): The shape of the process group mesh.
        rank (int): The rank of the current process.
    """

    def __init__(self, *size: int) -> None:
        assert dist.is_initialized(), "Please initialize torch.distributed first."
        assert prod(size) == dist.get_world_size(), "The product of the size must be equal to the world size."
        self._shape = size
        self._rank = dist.get_rank()
        self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
        # Two-way cache: sorted rank tuple -> group, and group -> sorted rank tuple.
        self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
        self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of the process group mesh."""
        return self._shape

    @property
    def rank(self) -> int:
        """Rank of the current process."""
        return self._rank

    def size(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
        """Get the size of the process group mesh.

        Args:
            dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None.

        Returns:
            Union[int, Tuple[int, ...]]: Size of the target dimension or the whole process group mesh.
        """
        if dim is None:
            return self._shape
        return self._shape[dim]

    def coordinate(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
        """Get the coordinate of the current process in the process group mesh.

        Args:
            dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None.

        Returns:
            Union[int, Tuple[int, ...]]: Coordinate of the target dimension or the whole coordinate.
        """
        if dim is None:
            return self._coord
        return self._coord[dim]

    @staticmethod
    def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """Convert a rank to a coordinate.

        Args:
            rank (int): Rank to be converted.
            shape (Tuple[int, ...]): Shape of the process group mesh.

        Returns:
            Tuple[int, ...]: Coordinate of the rank (as produced by ``numpy.unravel_index``).
        """
        return np.unravel_index(rank, shape)

    @staticmethod
    def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = "raise") -> int:
        """Convert a coordinate to a rank.

        mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
        With 'wrap', an out-of-range index is wrapped around.
        For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns (i % 2).

        Args:
            coord (Tuple[int, ...]): Coordinate to be converted.
            shape (Tuple[int, ...]): Shape of the process group mesh.
            mode (Optional[str]): The mode for numpy.ravel_multi_index.

        Returns:
            int: Rank of the coordinate.
        """
        assert mode in ["raise", "wrap", "clip"]
        return np.ravel_multi_index(coord, shape, mode)

    def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
        """Get the process group with the given ranks. If the process group doesn't exist, it will be created.

        Args:
            ranks_in_group (List[int]): Ranks in the process group. Order does not matter; the ranks are sorted before use.
            backend (Optional[str], optional): Backend of the process group. Defaults to None.

        Returns:
            ProcessGroup: The process group with the given ranks.
        """
        ranks_key = tuple(sorted(ranks_in_group))
        # Look the key up in the ranks -> group cache. `_group_to_ranks` is keyed
        # by ProcessGroup, so testing a rank tuple against it never hits and a
        # brand-new group would be created on every call.
        if ranks_key not in self._ranks_to_group:
            group = dist.new_group(list(ranks_key), backend=backend)
            self._ranks_to_group[ranks_key] = group
            self._group_to_ranks[group] = ranks_key
        return self._ranks_to_group[ranks_key]

    def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
        """Get the ranks in the given process group. The process group must be created by this class.

        Args:
            group (ProcessGroup): The process group.

        Returns:
            List[int]: Ranks in the process group.

        Raises:
            KeyError: If the group is not in this mesh's cache.
        """
        return list(self._group_to_ranks[group])

    @staticmethod
    def get_coords_along_axis(
        base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
    ) -> List[Tuple[int, ...]]:
        """Get coordinates along the given axis.

        Args:
            base_coord (Tuple[int, ...]): Base coordinate which the coordinates along the axis are based on.
            axis (int): Axis along which the coordinates are generated.
            indices_at_axis (List[int]): Indices at the axis.

        Returns:
            List[Tuple[int, ...]]: Coordinates along the axis, one per entry of ``indices_at_axis``.
        """
        # Replace only the entry at `axis`; every other entry is taken from `base_coord`.
        return [base_coord[:axis] + (idx,) + base_coord[axis + 1 :] for idx in indices_at_axis]

    def create_group_along_axis(
        self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
    ) -> ProcessGroup:
        """Create all process groups along the given axis, and return the one which the current process belongs to.

        Args:
            axis (int): Axis along which the process groups are created.
            indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None, meaning all indices of the axis.
            backend (Optional[str], optional): Backend of the process group. Defaults to None.

        Returns:
            ProcessGroup: The process group along the given axis which the current process belongs to,
                or ``None`` if the current rank is in none of the created groups.
        """
        indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
        reduced_shape = list(self._shape)
        # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
        reduced_shape[axis] = 1
        target_group = None
        # use Cartesian product to generate all combinations of coordinates;
        # every group along the axis is created, not only the one containing this rank
        for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
            coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
            ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
            group = self.get_group(ranks_in_group, backend=backend)
            if self._rank in ranks_in_group:
                target_group = group
        return target_group

    def get_group_along_axis(
        self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
    ) -> ProcessGroup:
        """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.

        Args:
            axis (int): Axis along which the process groups are created.
            indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None, meaning all indices of the axis.
            backend (Optional[str], optional): Backend of the process group. Defaults to None.

        Returns:
            ProcessGroup: The process group along the given axis which the current process belongs to.
        """
        indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
        coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
        # `get_group` stores groups under the *sorted* rank tuple, so the lookup key
        # is sorted too; otherwise unsorted `indices_at_axis` would never hit the cache.
        ranks_in_group = tuple(sorted(ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group))
        if ranks_in_group not in self._ranks_to_group:
            # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
            return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
        return self._ranks_to_group[ranks_in_group]
from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward,
send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward,
recv_forward, recv_backward)
from .ring import ring_forward
from .utils import send_obj_meta, recv_obj_meta
# Public API of this package: the names imported above from the
# `collective`, `p2p`, `ring` and `utils` submodules, re-exported as-is.
__all__ = [
'all_gather',
'reduce_scatter',
'all_reduce',
'broadcast',
'reduce',
'send_forward',
'send_forward_recv_forward',
'send_forward_backward_recv_forward_backward',
'send_backward',
'send_backward_recv_backward',
'send_backward_recv_forward',
'send_forward_recv_backward',
'recv_backward',
'recv_forward',
'ring_forward',
'send_obj_meta',
'recv_obj_meta',
]
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Accepted mode values; `None` is itself a valid entry.
# NOTE(review): presumably these are the supported tensor-parallel modes — confirm against the code that validates the config.
ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
# NOTE(review): presumably the key under which the selected tensor parallel mode is stored — confirm where it is read.
TENSOR_PARALLEL_MODE = 'tensor_parallel_mode'
# Maps a short parallelism name to an initializer name given as a string.
# NOTE(review): the values look like class names that are resolved elsewhere — confirm.
INITIALIZER_MAPPING = {
'data': 'Initializer_Data',
'tensor': 'Initializer_Tensor',
'pipeline': 'Initializer_Pipeline',
'embedding': 'Initializer_Embedding',
'1d': 'Initializer_1D',
'2d': 'Initializer_2D',
'2.5d': 'Initializer_2p5D',
'3d': 'Initializer_3D',
'sequence': 'Initializer_Sequence',
'model': 'Initializer_Model',
'moe': 'Initializer_Moe'
}
# String identifiers for the 3D parallelism groups
INPUT_GROUP_3D = 'input_group_3d'
WEIGHT_GROUP_3D = 'weight_group_3d'
OUTPUT_GROUP_3D = 'output_group_3d'
INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d'
OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d'
# Attribute names of tensor parallel parameters, plus a list collecting both
IS_TENSOR_PARALLEL = 'is_tensor_parallel'
NUM_PARTITIONS = 'num_partitions'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
from .config import Config, ConfigException
from .parallel_context import ParallelContext
from .parallel_mode import ParallelMode
from .moe_context import MOE_CONTEXT
from .process_group_initializer import *
from .random import *
# from .moe_context import MOE_CONTEXT
__all__ = [
"Config",
"ConfigException",
]
......@@ -5,6 +5,7 @@ import inspect
import sys
from importlib.machinery import SourceFileLoader
from pathlib import Path
from colossalai.logging import get_dist_logger
......@@ -41,7 +42,7 @@ class Config(dict):
self.__setattr__(key, value)
def update(self, config):
assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.'
assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
for k, v in config.items():
self._add_item(k, v)
return self
......@@ -66,11 +67,11 @@ class Config(dict):
elif isinstance(filename, Path):
filepath = filename.absolute()
assert filepath.exists(), f'{filename} is not found, please check your configuration path'
assert filepath.exists(), f"{filename} is not found, please check your configuration path"
# check extension
extension = filepath.suffix
assert extension == '.py', 'only .py files are supported'
assert extension == ".py", "only .py files are supported"
# import the config as module
remove_path = False
......@@ -86,13 +87,13 @@ class Config(dict):
config = Config()
for k, v in module.__dict__.items():
if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v):
if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
continue
else:
config._add_item(k, v)
logger = get_dist_logger()
logger.debug('variables which starts with __, is a module or class declaration are omitted in config file')
logger.debug("variables which starts with __, is a module or class declaration are omitted in config file")
# remove module
del sys.modules[module_name]
......
......@@ -3,21 +3,19 @@ from typing import Tuple
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor import ProcessGroup
from colossalai.legacy.tensor import ProcessGroup
def _check_sanity():
from colossalai.core import global_context as gpc
from colossalai.legacy.core import global_context as gpc
if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
raise NotImplementedError("Moe is not compatible with tensor or "
"pipeline parallel at present.")
raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.")
class MoeParallelInfo:
"""Moe parallelism information, storing parallel sizes and groups.
"""
"""Moe parallelism information, storing parallel sizes and groups."""
def __init__(self, ep_size: int, dp_size: int):
_check_sanity()
......@@ -61,10 +59,12 @@ class MoeContext(metaclass=SingletonMeta):
self.world_size = dist.get_world_size()
from colossalai.core import global_context as gpc
self.max_ep_size = gpc.config.get('max_ep_size', self.world_size)
assert self.world_size % self.max_ep_size == 0, \
"Maximum expert parallel size must be a factor of the number of GPUs"
from colossalai.legacy.core import global_context as gpc
self.max_ep_size = gpc.config.get("max_ep_size", self.world_size)
assert (
self.world_size % self.max_ep_size == 0
), "Maximum expert parallel size must be a factor of the number of GPUs"
self.min_dp_size = self.world_size // self.max_ep_size
# Enabling kernel optimization may raise error in some cases
......@@ -72,6 +72,7 @@ class MoeContext(metaclass=SingletonMeta):
self.use_kernel_optim = use_kernel_optim
from .random import moe_set_seed
moe_set_seed(seed)
self.has_setup = True
......@@ -92,8 +93,10 @@ class MoeContext(metaclass=SingletonMeta):
gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \
assert gt_flag or lt_flag, (
"Automatic experts placement dose not not support expert number"
" is not a multiple of ep size or vice versa."
)
# If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
# there are multiple experts in each GPU and each GPU has different experts
......
......@@ -16,6 +16,7 @@ class SingletonMeta(type):
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
else:
assert len(args) == 0 and len(
kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.'
assert (
len(args) == 0 and len(kwargs) == 0
), f"{cls.__name__} is a singleton class and a instance has been created."
return cls._instances[cls]
from .alpha_beta_profiler import AlphaBetaProfiler
from .calc_pipeline_strategy import alpa_dp
__all__ = ['AlphaBetaProfiler', 'alpa_dp']
__all__ = ["AlphaBetaProfiler", "alpa_dp"]
......@@ -13,7 +13,7 @@ FRAMEWORK_LATENCY = 0
class AlphaBetaProfiler:
'''
"""
Profile alpha and beta value for a given device list.
Usage:
......@@ -27,17 +27,19 @@ class AlphaBetaProfiler:
(1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
(1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
(4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
'''
"""
def __init__(self,
def __init__(
self,
physical_devices: List[int],
alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
ctype: str = 'a',
ctype: str = "a",
warmup: int = 5,
repeat: int = 25,
latency_iters: int = 5,
homogeneous_tolerance: float = 0.1):
'''
homogeneous_tolerance: float = 0.1,
):
"""
Args:
physical_devices: A list of device id, each element inside it is the global rank of that device.
alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
......@@ -45,7 +47,7 @@ class AlphaBetaProfiler:
warmup: Number of warmup iterations.
repeat: Number of iterations to measure.
latency_iters: Number of iterations to measure latency.
'''
"""
self.physical_devices = physical_devices
self.ctype = ctype
self.world_size = len(physical_devices)
......@@ -123,7 +125,7 @@ class AlphaBetaProfiler:
return (None, None)
def profile_latency(self, process_group, pg_handler):
'''
"""
This function is used to profile the latency of the given process group with a series of bytes.
Args:
......@@ -132,7 +134,7 @@ class AlphaBetaProfiler:
Returns:
latency: None if the latency is not measured, otherwise the median of the latency_list.
'''
"""
latency_list = []
for i in range(self.latency_iters):
nbytes = int(BYTE << i)
......@@ -148,26 +150,26 @@ class AlphaBetaProfiler:
return latency
def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
'''
"""
This function is used to profile the bandwidth of the given process group.
Args:
process_group: A tuple of global rank of the process group.
pg_handler: The handler of the process group.
'''
"""
(_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
return bandwidth
def profile_ab(self):
'''
"""
This method is used to profiling the alpha and beta value for a given device list.
Returns:
alpha_beta_dict: A dict which maps process group to its alpha and beta value.
'''
"""
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
rank = dist.get_rank()
global_pg_handler = dist.new_group(self.physical_devices)
dist.new_group(self.physical_devices)
def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
assert rank in process_group
......@@ -197,7 +199,7 @@ class AlphaBetaProfiler:
dist.broadcast_object_list(broadcast_list, src=process_group[0])
alpha_beta_dict[process_group] = tuple(broadcast_list)
# add symmetry pair to the apha_beta_dict
# add symmetry pair to the alpha_beta_dict
symmetry_ab_dict = {}
for process_group, alpha_beta_pair in alpha_beta_dict.items():
symmetry_process_group = (process_group[1], process_group[0])
......@@ -208,7 +210,7 @@ class AlphaBetaProfiler:
return alpha_beta_dict
def search_best_logical_mesh(self):
'''
"""
This method is used to search the best logical mesh for the given device list.
The best logical mesh is searched in following steps:
......@@ -232,19 +234,19 @@ class AlphaBetaProfiler:
>>> best_logical_mesh = profiler.search_best_logical_mesh()
>>> print(best_logical_mesh)
[[0, 1], [2, 3]]
'''
"""
def _power_of_two(integer):
return integer & (integer - 1) == 0
def _detect_homogeneous_device(alpha_beta_dict):
'''
"""
This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.
Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value
of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
* base_beta.
'''
"""
homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
for process_group, (_, beta) in alpha_beta_dict.items():
if homogeneous_device_dict is None:
......@@ -254,7 +256,8 @@ class AlphaBetaProfiler:
match_beta = None
for beta_value in homogeneous_device_dict.keys():
if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
1 - self.homogeneous_tolerance):
1 - self.homogeneous_tolerance
):
match_beta = beta_value
break
......@@ -267,9 +270,9 @@ class AlphaBetaProfiler:
return homogeneous_device_dict
def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
'''
"""
This function is used to check whether the homogeneous_group contains all physical devices.
'''
"""
flatten_mesh = []
for process_group in homogeneous_group:
flatten_mesh.extend(process_group)
......@@ -277,9 +280,9 @@ class AlphaBetaProfiler:
return len(non_duplicated_flatten_mesh) == len(self.physical_devices)
def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
'''
"""
This function is used to construct the largest ring in the homogeneous_group for each rank.
'''
"""
# Construct the ring
ring = []
ranks_in_ring = []
......@@ -300,7 +303,9 @@ class AlphaBetaProfiler:
check_rank = check_rank_list.pop()
for process_group in homogeneous_group:
if check_rank in process_group:
rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
rank_to_append = (
process_group[0] if process_group[1] == check_rank else process_group[1]
)
if rank_to_append not in ring_for_rank:
stable_status = False
rank_to_check_list.append(rank_to_append)
......@@ -314,7 +319,7 @@ class AlphaBetaProfiler:
assert _power_of_two(self.world_size)
power_of_two = int(math.log2(self.world_size))
median = power_of_two // 2
balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))
balanced_logical_mesh_shape = (2**median, 2 ** (power_of_two - median))
row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
balanced_logical_mesh = []
for row_index in range(row_size):
......@@ -348,7 +353,7 @@ class AlphaBetaProfiler:
return best_logical_mesh
def extract_alpha_beta_for_device_mesh(self):
'''
"""
Extract the mesh_alpha list and mesh_beta list based on the
best logical mesh, which will be used to initialize the device mesh.
......@@ -360,7 +365,7 @@ class AlphaBetaProfiler:
[2.5917552411556242e-05, 0.00010312341153621673]
>>> print(mesh_beta)
[5.875573704655635e-11, 4.7361584445959614e-12]
'''
"""
best_logical_mesh = self.search_best_logical_mesh()
first_axis = [row[0] for row in best_logical_mesh]
......@@ -381,7 +386,7 @@ class AlphaBetaProfiler:
first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
mesh_alpha = [first_latency, second_latency]
# The beta values have been enlarged by 1e10 times temporarilly because the computation cost
# The beta values have been enlarged by 1e10 times temporarily because the computation cost
# is still estimated in the unit of TFLOPs instead of time. We will remove this factor in future.
mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth]
......
......@@ -10,8 +10,10 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
while i <= num_devices_per_host:
i *= 2
p += 1
assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, "
f"while now num_devices_per_host = {num_devices_per_host}")
assert pow(2, p) == num_devices_per_host, (
"Only supports the cases where num_devices_per_host is power of two, "
f"while now num_devices_per_host = {num_devices_per_host}"
)
if mode == "alpa":
for i in range(p + 1):
submesh_choices.append((1, pow(2, i)))
......@@ -24,8 +26,9 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
return submesh_choices
def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
best_configs):
def alpa_dp_impl(
num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, best_configs
):
"""Implementation of Alpa DP for pipeline strategy
Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
......@@ -54,7 +57,7 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
for i in range(num_layers, k, -1):
stage_cost = compute_cost[k, i, m]
new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost
if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]):
if stage_cost <= max_stage_cost and new_cost < f[s, k, d]:
f[s, k, d] = new_cost
f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
f_argmin[s, k, d] = (i, m, best_configs[k, i, m])
......@@ -75,24 +78,20 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
res = []
while current_s > 0 and current_layer < num_layers and current_devices > 0:
next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
next_start_layer, submesh_choice, autosharding_choice = f_argmin[current_s, current_layer, current_devices]
assert next_start_layer != -1 and current_devices != -1
res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
current_s -= 1
current_layer = next_start_layer
current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
assert (current_s == 0 and current_layer == num_layers and current_devices == 0)
assert current_s == 0 and current_layer == num_layers and current_devices == 0
return total_cost, res
def alpa_dp(num_layers,
num_devices,
num_microbatches,
submesh_choices,
num_autosharding_configs,
compute_cost,
gap=1e-6):
def alpa_dp(
num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, gap=1e-6
):
"""Alpa auto stage dynamic programming.
Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
......@@ -101,8 +100,12 @@ def alpa_dp(num_layers,
num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs)
"""
assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices),
num_autosharding_configs), "Cost shape wrong."
assert np.shape(compute_cost) == (
num_layers,
num_layers,
len(submesh_choices),
num_autosharding_configs,
), "Cost shape wrong."
all_possible_stage_costs = np.sort(np.unique(compute_cost))
best_cost = np.inf
best_solution = None
......@@ -117,8 +120,9 @@ def alpa_dp(num_layers,
break
if max_stage_cost - last_max_stage_cost < gap:
continue
cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost,
max_stage_cost, best_configs)
cost, solution = alpa_dp_impl(
num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, max_stage_cost, best_configs
)
if cost < best_cost:
best_cost = cost
best_solution = solution
......
......@@ -3,11 +3,19 @@
with some changes. """
import operator
from dataclasses import dataclass
from functools import reduce
from typing import List, Tuple
from typing import Dict, List, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@dataclass
class ProcessGroupContainer:
    """Plain record pairing a process group handle with a list of ranks."""

    # handle of a torch.distributed process group
    process_group: ProcessGroup
    # ranks stored alongside the handle
    # NOTE(review): presumably the global ranks that are members of `process_group` - confirm
    ranks: List[int]
# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)
......@@ -27,223 +35,491 @@ class DeviceMesh:
during initializing the DeviceMesh instance if the init_process_group set to True.
Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group.
(default: False)
need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True.
device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
"""
def __init__(self,
_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
def __init__(
self,
physical_mesh_id: torch.Tensor,
mesh_shape: torch.Size = None,
logical_mesh_id: torch.Tensor = None,
mesh_alpha: List[float] = None,
mesh_beta: List[float] = None,
init_process_group: bool = False,
need_flatten: bool = True):
self.physical_mesh_id = physical_mesh_id
device: str = "cuda",
):
# ============================
# Physical & Logical Mesh IDs
# ============================
self._physical_mesh_id = physical_mesh_id
assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor."
# logical mesh ids can be obtained via two ways
# 1. provide physical mesh id and provide mesh shape
# 2. directly supply the logical mesh id
assert mesh_shape is None or logical_mesh_id is None, (
"Only one of mesh_shape and logical_mesh_id can be specified."
"Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"
)
if logical_mesh_id is None:
self.mesh_shape = mesh_shape
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
self._mesh_shape = mesh_shape
self._logical_mesh_id = self._physical_mesh_id.reshape(self._mesh_shape)
else:
self._logical_mesh_id = logical_mesh_id
self.mesh_shape = self._logical_mesh_id.shape
self._mesh_shape = self._logical_mesh_id.shape
# ensure two things:
# 1. logical and physical mesh IDs should contain the same elements
# 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
assert torch.equal(
torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)
), "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
assert (
torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel()
), "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
assert (
torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel()
), "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
# map global rank into logical rank
self.convert_map = {}
self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
# ===============================================
# coefficient for alpha-beta communication model
# alpha is latency and beta is bandwidth
# ===============================================
# if the values are not provided, we assume they are 1 for simplicity
if mesh_alpha is None:
mesh_alpha = [1] * len(self.mesh_shape)
mesh_alpha = [1] * len(self._mesh_shape)
if mesh_beta is None:
mesh_beta = [1] * len(self.mesh_shape)
mesh_beta = [1] * len(self._mesh_shape)
self.mesh_alpha = tuple(mesh_alpha)
self.mesh_beta = tuple(mesh_beta)
self.init_process_group = init_process_group
self.need_flatten = need_flatten
if self.init_process_group:
self.process_groups_dict = self.create_process_groups_for_logical_mesh()
if self.need_flatten and self._logical_mesh_id.dim() > 1:
self.flatten_device_mesh = self.flatten()
# Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
# self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
# self.mesh_beta)
# ensure the alpha and beta have the same shape
assert len(self.mesh_alpha) == len(
self.mesh_beta
), "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
# =========================
# Device for Process Group
# =========================
self._device = device
self._dist_backend = self._DIST_BACKEND[device]
# =========================
# Process Group Management
# =========================
# the _global_to_local_rank_mapping is structured as follows
# {
# <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
# }
self._global_to_local_rank_mapping = dict()
self._init_global_to_logical_rank_mapping(
mapping=self._global_to_local_rank_mapping, tensor=self.logical_mesh_id
)
# create process group
self._process_group_dict = {}
self._ranks_in_the_process_group = {}
self._global_rank_of_current_process = None
self._is_initialized = False
# attribute used to indicate whether this object
# is created using DeviceMesh.from_process_group
# this attribute can be used to do some check in methods
# such get_process_group as no global rank information
# is known if created with from_process_group
self._is_init_from_process_group = False
# initialize process group if specified
self._init_ranks_in_the_same_group()
self._init_process_group = init_process_group
if init_process_group:
self.init_logical_process_group()
@property
def shape(self):
return self.mesh_shape
def shape(self) -> torch.Size:
"""
Return the shape of the logical mesh.
"""
return self._mesh_shape
@property
def num_devices(self):
return reduce(operator.mul, self.physical_mesh_id.shape, 1)
def num_devices(self) -> int:
"""
Return the number of devices contained in the device mesh.
"""
return reduce(operator.mul, self._physical_mesh_id.shape, 1)
@property
def logical_mesh_id(self):
def logical_mesh_id(self) -> torch.Tensor:
"""
Return the logical mesh id.
"""
return self._logical_mesh_id
def __deepcopy__(self, memo):
@property
def is_initialized(self) -> bool:
    """
    Return whether the process group is initialized.

    This is a read-only view of the private ``_is_initialized`` flag.
    """
    return self._is_initialized
@staticmethod
def from_process_group(process_group: Union[ProcessGroup, List[ProcessGroup]]) -> "DeviceMesh":
    """
    Build a DeviceMesh from process group(s) that already exist for the current process.

    A mesh created this way carries no physical mesh id, so it cannot be used to query
    other ranks or to perform alpha-beta communication estimation.

    Args:
        process_group (Union[ProcessGroup, List[ProcessGroup]]): a single process group, which
            yields a 1D mesh, or a list of process groups where the group at index ``i``
            is used for axis ``i`` of the mesh.

    Returns:
        DeviceMesh: the device mesh instance.
    """

    def _backend_to_device(group):
        """Map the backend of ``group`` to a device name, or None when the backend is unknown."""
        backend_name = dist.get_backend(group)
        return next(
            (dev for dev, name in DeviceMesh._DIST_BACKEND.items() if name == backend_name),
            None,
        )

    # normalise the input to a list with one process group per axis
    groups = [process_group] if isinstance(process_group, ProcessGroup) else process_group

    # the size of each axis is the world size of the group on that axis
    axis_sizes = [dist.get_world_size(group) for group in groups]

    # every axis has to live on the same kind of device
    devices = [_backend_to_device(group) for group in groups]
    assert all(
        dev == devices[0] for dev in devices
    ), "All devices should be the same, please check your input process groups are created with the same distributed backend."

    # Only the groups of the current process are known, so the global ranks of the other
    # processes cannot be recovered. A placeholder physical mesh id is used just to get
    # through the constructor and is discarded right after.
    total_ranks = reduce(operator.mul, axis_sizes, 1)
    placeholder_ids = torch.arange(total_ranks)
    mesh = DeviceMesh(physical_mesh_id=placeholder_ids, mesh_shape=axis_sizes, device=devices[0])

    # overwrite the attributes derived from the placeholder ids
    mesh._physical_mesh_id = None
    mesh._logical_mesh_id = None
    current_rank = dist.get_rank()
    mesh._global_rank_of_current_process = current_rank
    mesh._is_initialized = False
    # NOTE(review): `_is_init_from_process_group` is not set here - confirm whether it should be True
    mesh._process_group_dict = {current_rank: dict(enumerate(groups))}
    return mesh
def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
    """
    Look up the process group of a rank on one axis of the mesh.

    Args:
        axis (int): the axis of the process group.
        global_rank (int, optional): the global rank whose group is wanted. When omitted,
            the rank stored for the current process is used. (default: None)

    Returns:
        ProcessGroup: the entry stored for ``global_rank`` and ``axis``.

    Raises:
        RuntimeError: if an explicit ``global_rank`` is given on a mesh flagged as created
            with ``DeviceMesh.from_process_group``.
    """
    if global_rank is not None:
        # an explicit rank needs global rank information
        if self._is_init_from_process_group:
            raise RuntimeError(
                "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
            )
        lookup_rank = global_rank
    else:
        lookup_rank = self._global_rank_of_current_process

    groups_by_axis = self._process_group_dict[lookup_rank]
    return groups_by_axis[axis]
def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
    """
    Return the axis -> process group dictionary stored for a rank.

    Args:
        global_rank (int, optional): the global rank of the process. When omitted,
            the rank stored for the current process is used. (default: None)

    Returns:
        Dict[int, ProcessGroup]: the process groups of the rank, keyed by axis.

    Raises:
        RuntimeError: if an explicit ``global_rank`` is given on a mesh flagged as created
            with ``DeviceMesh.from_process_group``.
    """
    # no rank given: answer for the current process
    if global_rank is None:
        return self._process_group_dict[self._global_rank_of_current_process]

    # an explicit rank needs global rank information
    if self._is_init_from_process_group:
        raise RuntimeError(
            "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
        )
    return self._process_group_dict[global_rank]
def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
    """
    Return the ranks recorded for the process group of a rank on one axis.

    Args:
        axis (int): the axis of the process group.
        global_rank (int, optional): the global rank of the process. When omitted,
            the rank stored for the current process is used. (default: None)

    Returns:
        List[int]: the ranks stored for ``global_rank`` on ``axis``.

    Raises:
        RuntimeError: if an explicit ``global_rank`` is given on a mesh flagged as created
            with ``DeviceMesh.from_process_group``.
    """
    use_current_process = global_rank is None
    if not use_current_process and self._is_init_from_process_group:
        raise RuntimeError(
            "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
        )

    target_rank = self._global_rank_of_current_process if use_current_process else global_rank
    return self._ranks_in_the_process_group[target_rank][axis]
def __deepcopy__(self, memo) -> "DeviceMesh":
    """
    Deep-copy the mesh while sharing the process group handles.

    Every attribute is deep-copied except ``_process_group_dict``, which is assigned
    to the copy by reference.

    Args:
        memo: the memo dictionary passed along by ``copy.deepcopy``.

    Returns:
        DeviceMesh: a new instance of the same class.
    """
    import copy

    cls = self.__class__
    # bypass __init__ so that no attribute is recomputed for the copy
    result = cls.__new__(cls)
    # register the copy before descending so that reference cycles resolve to it
    memo[id(self)] = result
    for name, value in self.__dict__.items():
        if name == "_process_group_dict":
            # process group cannot be copied
            # thus, we share them directly
            setattr(result, name, value)
        else:
            setattr(result, name, copy.deepcopy(value, memo))
    return result
def flatten(self):
"""
Flatten the logical mesh into an effective 1d logical mesh,
def _init_global_to_logical_rank_mapping(
    self, mapping: Dict, tensor: torch.Tensor, index_list: List[int] = None
) -> Dict[int, List[int]]:
    """
    Build a global rank to local rank mapping for each process group in different axis in the logical device mesh.

    The tensor is walked recursively, one axis per recursion level, and ``mapping`` is filled in place.

    Args:
        mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
        tensor (torch.Tensor): the tensor that contains the logical mesh ids.
        index_list (List[int], optional): the local ranks on the axes that have already been walked.
            Callers normally leave it unset. (default: None)

    Returns:
        mapping (Dict): the same dictionary that was passed in. The value is a list of integers and
            each integer represents the local rank in the indexed axis.
    """
    prefix = [] if index_list is None else index_list

    for index, inner_tensor in enumerate(tensor):
        # index means the local rank in the current axis
        # inner_tensor refers to the processes with the same local rank
        coordinates = prefix + [index]

        if inner_tensor.dim() == 0:
            # A 0-dim tensor is a single mesh id, so the last axis has been reached.
            # The test is on dim() rather than numel(): a sub-tensor of shape (1,) also has one
            # element but still has an axis left, and stopping there would drop that coordinate.
            mapping[int(inner_tensor)] = coordinates
        else:
            # go one axis deeper, carrying the local ranks collected so far
            self._init_global_to_logical_rank_mapping(mapping, inner_tensor, coordinates)

    return mapping
def create_process_groups_for_logical_mesh(self):
'''
def init_logical_process_group(self):
    """
    This method is used to initialize the logical process groups which will be used in communications
    among logical device mesh.

    One process group is created per distinct (axis, ranks) combination, and its handle is stored in
    ``self._process_group_dict`` for every rank that belongs to it.

    Note: if init_process_group set to False, you have to call this method manually. Otherwise,
    the communication related function, such as ShapeConsistencyManager.apply will raise errors.

    Raises:
        AssertionError: if torch.distributed is not initialized, or if this method has already been called.
    """
    # sanity check
    assert (
        dist.is_initialized()
    ), "The torch.distributed should be initialized before calling init_logical_process_group"
    assert (
        not self._is_initialized
    ), "The logical process group has been initialized, do not call init_logical_process_group twice"

    # update the global rank of the current process
    self._global_rank_of_current_process = dist.get_rank()

    # (axis, ranks) pairs for which a process group has already been created;
    # the axis is part of the key so that identical rank lists on different axes stay distinct
    created_groups = set()

    # flatten the global ranks to 1D list
    global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()

    for global_rank in global_rank_flatten_list:
        # find the other ranks which are in the same process group as global_rank
        ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)

        for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
            # skip duplicated process group creation: every member of a group reports the same rank list
            group_key = (axis, tuple(ranks_in_same_group))
            if group_key in created_groups:
                continue
            created_groups.add(group_key)

            # create the process group
            pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend)

            # keep this process group in the process_groups_dict for all of its members
            for rank in ranks_in_same_group:
                self._process_group_dict.setdefault(rank, {})[axis] = pg_handler

    # update the init flag
    # we only allow init for once
    self._is_initialized = True
def _init_ranks_in_the_same_group(self):
    """
    This method is used to initialize the ranks_in_the_same_group dictionary.

    After the call, ``self._ranks_in_the_process_group[global_rank][axis]`` holds the list of global
    ranks that share a process group with ``global_rank`` on ``axis``, for every rank in the physical
    mesh and every axis reported by ``_collate_global_ranks_in_same_process_group``.
    """
    # flatten the global ranks to 1D list
    global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()

    for global_rank in global_rank_flatten_list:
        # find the other ranks which are in the same process group as global_rank
        ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)

        for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
            # Create the per-rank dict only once. The existence check must look at the dict being
            # filled, otherwise the entry is recreated for every axis and only the last axis survives.
            self._ranks_in_the_process_group.setdefault(global_rank, {})[axis] = ranks_in_same_group
def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]:
    """
    Return the local rank of the given global rank in the logical device mesh.

    Args:
        rank (int): the global rank in the logical device mesh.
        axis (int, optional): the axis of the logical device mesh. When omitted, the local ranks
            on all axes are returned. (default: None)

    Returns:
        Union[List[int], int]: the local rank on ``axis`` if ``axis`` is given, otherwise the list of
            local ranks on every axis.

    Raises:
        RuntimeError: if the mesh is flagged as created with ``DeviceMesh.from_process_group``.
    """
    if self._is_init_from_process_group:
        raise RuntimeError(
            "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
        )

    local_ranks = self._global_to_local_rank_mapping[rank]
    # compare against None explicitly: axis 0 is a valid axis but is falsy
    if axis is None:
        return local_ranks
    return local_ranks[axis]
def _collate_global_ranks_in_same_process_group(self, global_rank):
    """
    Give a global rank and return all global ranks involved in its associated process group in each axis.

    Example:

    ```python
    physical_mesh_id = torch.arange(0, 16)
    mesh_shape = (4, 4)

    # logical mesh will look like
    # [[0, 1, 2, 3],
    #  [4, 5, 6, 7],
    #  [8, 9, 10,11],
    #  [12,13,14,15]]

    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    print(device_mesh._collate_global_ranks_in_same_process_group(0))

    # key is axis name
    # value is a list of global ranks in same axis with rank 0
    # output will look like
    # {
    #     0: [0, 4, 8, 12],
    #     1: [0, 1, 2, 3]
    # }
    ```

    Args:
        global_rank: the global rank whose process group members are wanted.

    Returns:
        A dict keyed by axis; each value lists the global ranks sharing a process group with
        ``global_rank`` on that axis, ordered by their local rank on the axis.
    """
    # self._global_to_local_rank_mapping maps a global rank to its list of local ranks, one per axis.
    # That list can be seen as the coordinates of the process in the logical mesh. Inverting the
    # mapping once lets us go from coordinates back to the global rank with a single lookup
    # instead of scanning the whole mapping for every coordinate.
    coordinates_to_global_rank = {
        tuple(coordinates): rank for rank, coordinates in self._global_to_local_rank_mapping.items()
    }

    # coordinates of the process we were asked about
    base_coordinates = self._global_to_local_rank_mapping[global_rank]

    # the key of the dict is the axis
    # the value is the list of global ranks which are in the same process group as the given global rank
    global_pg_ranks = {}
    for dim in range(self.logical_mesh_id.dim()):
        ranks_on_axis = []
        # iterate over the dimension size so that we can include all processes
        # in the same process group in the given axis
        for local_rank in range(self.logical_mesh_id.shape[dim]):
            # keep every coordinate fixed except the one on the current axis
            peer_coordinates = list(base_coordinates)
            peer_coordinates[dim] = local_rank

            peer_global_rank = coordinates_to_global_rank.get(tuple(peer_coordinates))
            if peer_global_rank is not None:
                ranks_on_axis.append(peer_global_rank)
        global_pg_ranks[dim] = ranks_on_axis
    return global_pg_ranks
def flatten(self):
    """
    Flatten the logical mesh into an effective 1d logical mesh.

    The flattened mesh reuses the physical mesh id, the ``init_process_group`` setting and the
    device of this mesh, and uses the largest alpha and beta of this mesh as its coefficients.

    Returns:
        DeviceMesh: a new mesh whose shape is ``(self.num_devices,)``.

    Raises:
        RuntimeError: if the mesh is flagged as created with ``DeviceMesh.from_process_group``.
    """
    if self._is_init_from_process_group:
        raise RuntimeError(
            "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
        )

    flatten_mesh_shape_size = len(self._mesh_shape)
    flatten_mesh_shape = [self.num_devices]
    # NOTE(review): the coefficient lists get (number of axes - 1) entries, which is exactly one
    # entry only for a 2D mesh - confirm this is intended for meshes of other ranks
    return DeviceMesh(
        self._physical_mesh_id,
        tuple(flatten_mesh_shape),
        mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
        mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
        init_process_group=self._init_process_group,
        # keep the flattened mesh on the same device instead of falling back to the default
        device=self._device,
    )
def all_gather_cost(self, num_bytes, mesh_dim):
    """
    Estimate the cost of an all-gather along one axis of the logical mesh.

    Args:
        num_bytes: the amount of data, in bytes.
        mesh_dim: the axis of the logical mesh the collective runs on.

    Returns:
        ``alpha + beta * (n - 1) / n * num_bytes + 0.1`` where ``n`` is the size of the axis and
        ``alpha`` / ``beta`` are the coefficients stored for that axis.
    """
    group_size = self.logical_mesh_id.shape[mesh_dim]
    latency_term = self.mesh_alpha[mesh_dim]
    transfer_term = self.mesh_beta[mesh_dim] * (group_size - 1) / group_size * num_bytes
    return latency_term + transfer_term + 0.1
def all_reduce_cost(self, num_bytes, mesh_dim):
    """
    Estimate the cost of an all-reduce along one axis of the logical mesh.

    Args:
        num_bytes: the amount of data, in bytes.
        mesh_dim: the axis of the logical mesh the collective runs on.

    Returns:
        ``alpha + beta * 2 * (n - 1) / n * num_bytes + 0.01`` where ``n`` is the size of the axis and
        ``alpha`` / ``beta`` are the coefficients stored for that axis.
    """
    group_size = self.logical_mesh_id.shape[mesh_dim]
    latency_term = self.mesh_alpha[mesh_dim]
    transfer_term = self.mesh_beta[mesh_dim] * 2 * (group_size - 1) / group_size * num_bytes
    return latency_term + transfer_term + 0.01
def reduce_scatter_cost(self, num_bytes, mesh_dim):
    """
    Estimate the cost of a reduce-scatter along one axis of the logical mesh.

    Args:
        num_bytes: the amount of data, in bytes.
        mesh_dim: the axis of the logical mesh the collective runs on.

    Returns:
        ``alpha + beta * (n - 1) / n * num_bytes + 0.001`` where ``n`` is the size of the axis and
        ``alpha`` / ``beta`` are the coefficients stored for that axis.
    """
    group_size = self.logical_mesh_id.shape[mesh_dim]
    latency_term = self.mesh_alpha[mesh_dim]
    transfer_term = self.mesh_beta[mesh_dim] * (group_size - 1) / group_size * num_bytes
    return latency_term + transfer_term + 0.001
def all_to_all_cost(self, num_bytes, mesh_dim):
    """
    Estimate the cost of an all-to-all along one axis of the logical mesh.

    Args:
        num_bytes: the amount of data, in bytes.
        mesh_dim: the axis of the logical mesh the collective runs on.

    Returns:
        ``alpha + beta * (n - 1) / n / n * num_bytes * (n / 2.0) + 0.001`` where ``n`` is the size of
        the axis and ``alpha`` / ``beta`` are the coefficients stored for that axis.
    """
    group_size = self.logical_mesh_id.shape[mesh_dim]
    # the transfer term is scaled by half of the axis size
    penalty_factor = group_size / 2.0
    latency_term = self.mesh_alpha[mesh_dim]
    transfer_term = (
        self.mesh_beta[mesh_dim] * (group_size - 1) / group_size / group_size * num_bytes * penalty_factor
    )
    return latency_term + transfer_term + 0.001
class FlattenDeviceMesh(DeviceMesh):
    """
    A DeviceMesh variant whose cost coefficients are collapsed to scalars.

    The mesh shape is kept as given; only ``mesh_alpha`` and ``mesh_beta`` are
    reduced to single values, and a table of flattened process numbers is
    stored in ``process_number_dict``.
    """

    def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
        """
        Args:
            physical_mesh_id: forwarded unchanged to ``DeviceMesh.__init__``.
            mesh_shape: forwarded unchanged to ``DeviceMesh.__init__``.
            mesh_alpha: forwarded unchanged to ``DeviceMesh.__init__``.
            mesh_beta: forwarded unchanged to ``DeviceMesh.__init__``.
        """
        super().__init__(
            physical_mesh_id,
            mesh_shape,
            mesh_alpha,
            mesh_beta,
            init_process_group=False,
            need_flatten=False,
        )
        # Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars
        self.mesh_alpha = max(self.mesh_alpha)
        self.mesh_beta = min(self.mesh_beta)
        # Different from original process_groups_dict, rank_list is not stored
        self.process_number_dict = self.create_process_numbers_for_logical_mesh()

    def create_process_numbers_for_logical_mesh(self):
        """
        Build 1d DeviceMesh in column-major(0) and row-major(1)
        for example:
            mesh_shape = (2,4)
            # [[0, 1, 2, 3],
            #  [4, 5, 6, 7]]
            # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}

        Returns:
            dict: key 0 maps to the column-major list, key 1 to the row-major list.
        """
        num_devices = reduce(operator.mul, self.mesh_shape, 1)
        # Build the rank grid once and derive both orderings from it.
        rank_grid = torch.arange(num_devices).reshape(self.mesh_shape)
        process_numbers_dict = {}
        # transpose(1, 0) swaps the first two axes, so a 2-d mesh_shape is expected here.
        process_numbers_dict[0] = rank_grid.transpose(1, 0).flatten().tolist()
        process_numbers_dict[1] = rank_grid.flatten().tolist()
        return process_numbers_dict

    def mix_gather_cost(self, num_bytes):
        """
        Estimate the cost of a gather over every device of the mesh.

        Uses the scalar ``mesh_alpha`` / ``mesh_beta`` set in ``__init__``:
        ``mesh_alpha + mesh_beta * (n - 1) / n * num_bytes + 0.1`` with ``n``
        the product of ``self.mesh_shape``.

        Args:
            num_bytes: amount of data being communicated.

        Returns:
            The estimated cost as a number.
        """
        num_devices = reduce(operator.mul, self.mesh_shape, 1)
        return self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1
......@@ -2,16 +2,14 @@ from typing import Callable
import torch
# Major and minor components of the installed PyTorch version, taken from the
# first two dot-separated fields of torch.__version__.
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 12:
META_COMPATIBILITY = False
elif TORCH_MAJOR == 1 and TORCH_MINOR == 12:
from . import _meta_regist_12
META_COMPATIBILITY = True
elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
from . import _meta_regist_13
META_COMPATIBILITY = True
elif TORCH_MAJOR == 2:
META_COMPATIBILITY = True
......@@ -36,7 +34,7 @@ def compatibility(is_backward_compatible: bool = False) -> Callable:
else:
def wrapper(*args, **kwargs):
raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')
raise RuntimeError(f"Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}")
return wrapper
......
......@@ -3,7 +3,7 @@
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
from typing import Callable, List, Optional, Tuple, Union
from typing import List, Optional, Union
import torch
from torch.utils._pytree import tree_map
......@@ -16,13 +16,11 @@ meta_table = {}
def register_meta(op, register_dispatcher=True):
def wrapper(f):
def add_func(op):
meta_table[op] = f
if register_dispatcher:
name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
try:
meta_lib.impl(name, f)
except:
......@@ -48,7 +46,6 @@ def meta_conv(
output_padding: List[int],
groups: int,
):
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
Formula to apply to calculate the length of some dimension of the output
......@@ -125,7 +122,8 @@ def meta_conv(
kernel_size[i],
stride[i],
output_padding_list[i],
))
)
)
else:
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
return ret_shape
......@@ -164,17 +162,37 @@ def meta_conv(
@register_meta(aten._convolution.default)
def meta_conv_1(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    is_transposed: bool,
    output_padding: List[int],
    groups: int,
    *extra_args,
):
    """
    Meta implementation registered for ``aten._convolution.default``.

    Forwards the first nine arguments to ``meta_conv`` and returns its result.
    Any additional positional arguments are accepted and ignored.
    """
    out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
    return out
@register_meta(aten.convolution_backward.default)
def meta_conv_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias_sizes,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    """
    Meta implementation registered for ``aten.convolution_backward.default``.

    Returns:
        A 3-tuple of uninitialized tensors: one shaped like ``input``, one
        shaped like ``weight``, and one of size ``bias_sizes`` on the meta
        device. The remaining arguments are not used.
    """
    return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device="meta")
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
......@@ -208,7 +226,6 @@ def meta_cuda_rnn(
batch_sizes,
dropout_state,
):
is_input_packed = len(batch_sizes) != 0
if is_input_packed:
seq_length = len(batch_sizes)
......@@ -224,8 +241,11 @@ def meta_cuda_rnn(
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
out_shape = ([mini_batch, seq_length, out_size *
num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
out_shape = (
[mini_batch, seq_length, out_size * num_directions]
if batch_first
else [seq_length, mini_batch, out_size * num_directions]
)
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
......@@ -242,18 +262,20 @@ def meta_cuda_rnn(
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default)
def meta_cudnn_rnn_backward(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_stride0: int,
    hx: torch.Tensor,
    cx: Optional[torch.Tensor] = None,
    *args,
    **kwargs,
):
    """
    Meta implementation registered for ``aten._cudnn_rnn_backward.default``.

    Returns:
        A 4-tuple ``(grad_input, grad_weight, grad_hx, grad_cx)`` of
        uninitialized tensors shaped like ``input``, ``weight``, ``hx`` and
        ``cx`` respectively. When ``cx`` is None, ``grad_cx`` is an empty
        zero-dimensional tensor on the meta device.
    """
    # The leftover debug print of the inputs was dropped; it only wrote to stdout.
    grad_input = torch.empty_like(input)
    grad_weight = torch.empty_like(weight)
    grad_hx = torch.empty_like(hx)
    grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device="meta")
    return grad_input, grad_weight, grad_hx, grad_cx
......@@ -298,15 +320,25 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
n_input = input.size(1)
output = torch.empty_like(input)
running_mean = torch.empty((n_input), device='meta')
running_var = torch.empty((n_input), device='meta')
running_mean = torch.empty((n_input), device="meta")
running_var = torch.empty((n_input), device="meta")
return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
save_invstd, train, eps, output_mask):
def meta_bn_backward(
dY: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
running_mean,
running_var,
save_mean,
save_invstd,
train,
eps,
output_mask,
):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
......@@ -319,9 +351,9 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
n_input = input.size(1)
output = torch.empty_like(input)
running_mean = torch.empty((n_input), device='meta')
running_var = torch.empty((n_input), device='meta')
reserve = torch.empty((0), dtype=torch.uint8, device='meta')
running_mean = torch.empty((n_input), device="meta")
running_var = torch.empty((n_input), device="meta")
reserve = torch.empty((0), dtype=torch.uint8, device="meta")
return output, running_mean, running_var, reserve
......@@ -330,8 +362,17 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
save_mean, save_invstd, eps, reserve):
def meta_cudnn_bn_backward(
dY: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
running_mean,
running_var,
save_mean,
save_invstd,
eps,
reserve,
):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
......@@ -345,15 +386,16 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
n_input = input.size(1)
output = torch.empty_like(input)
running_mean = torch.empty((bs, n_input, 1), device='meta')
running_var = torch.empty((bs, n_input, 1), device='meta')
running_mean = torch.empty((bs, n_input, 1), device="meta")
running_var = torch.empty((bs, n_input, 1), device="meta")
return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
grad_input_mask):
def meta_ln_backward(
dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(bias)
......@@ -397,16 +439,19 @@ def meta_index_Tensor(self, indices):
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
assert index.dtype in [torch.long, torch.int8, torch.bool],\
"tensors used as indices must be long, byte or bool tensors"
assert index.dtype in [
torch.long,
torch.int8,
torch.bool,
], "tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
assert index.shape[j] == self.shape[
k +
j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
assert (
index.shape[j] == self.shape[k + j]
), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
......@@ -482,12 +527,15 @@ def meta_index_Tensor(self, indices):
# ============================== Embedding =========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(
    grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
):
    """
    Meta implementation registered for ``aten.embedding_dense_backward.default``.

    Returns:
        An uninitialized tensor of shape ``(num_weights, grad_output.size(-1))``
        that copies the dtype, device and layout of ``grad_output``. The
        ``indices``, ``padding_idx`` and ``scale_grad_by_freq`` arguments are
        not used.
    """
    return torch.empty(
        (num_weights, grad_output.size(-1)),
        dtype=grad_output.dtype,
        device=grad_output.device,
        layout=grad_output.layout,
    )
# ============================== Dropout ===========================================
......
from typing import Any, Callable, Dict, Iterable, List, Tuple
from typing import Any, Dict, Iterable, List, Tuple
import torch
......@@ -18,6 +18,7 @@ try:
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
CODEGEN_AVAILABLE = True
except:
from torch.fx.graph import (
......@@ -32,12 +33,13 @@ except:
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
CODEGEN_AVAILABLE = False
# Export whichever entry point matches the torch.fx import that succeeded above.
if CODEGEN_AVAILABLE:
    __all__ = ["ActivationCheckpointCodeGen"]
else:
    __all__ = ["python_code_with_activation_checkpoint"]
def _gen_saved_tensors_hooks():
......@@ -125,15 +127,14 @@ def _find_ckpt_regions(nodes: List[Node]):
Find the checkpoint regions given a list of consecutive nodes. The outputs will be list
of tuples, each tuple is in the form of (start_index, end_index).
"""
ckpt_nodes = []
ckpt_regions = []
start = -1
end = -1
current_region = None
for idx, node in enumerate(nodes):
if 'activation_checkpoint' in node.meta:
act_ckpt_label = node.meta['activation_checkpoint']
if "activation_checkpoint" in node.meta:
act_ckpt_label = node.meta["activation_checkpoint"]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
......@@ -150,7 +151,7 @@ def _find_ckpt_regions(nodes: List[Node]):
current_region = act_ckpt_label
start = idx
end = -1
elif current_region is not None and not 'activation_checkpoint' in node.meta:
elif current_region is not None and not "activation_checkpoint" in node.meta:
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
......@@ -178,8 +179,8 @@ def _find_offload_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable):
act_offload_label = node.meta['activation_offload']
if "activation_offload" in node.meta and isinstance(node.meta["activation_offload"], Iterable):
act_offload_label = node.meta["activation_offload"]
if current_region == None:
current_region = act_offload_label
......@@ -226,9 +227,9 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
"""
Generate the checkpoint function call code text
"""
outputs = ', '.join(output_vars)
inputs = ', '.join(input_vars)
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
outputs = ", ".join(output_vars)
inputs = ", ".join(input_vars)
return f"{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})"
def _end_of_ckpt(node: Node, check_idx: int) -> bool:
......@@ -240,9 +241,9 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool:
Returns:
bool
"""
if 'activation_checkpoint' in node.meta:
if isinstance(node.meta['activation_checkpoint'], list):
return node.meta['activation_checkpoint'][check_idx] == None
if "activation_checkpoint" in node.meta:
if isinstance(node.meta["activation_checkpoint"], list):
return node.meta["activation_checkpoint"][check_idx] == None
else:
return False
else:
......@@ -260,11 +261,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
current_region = None
for idx, node in enumerate(nodes):
if 'activation_checkpoint' in node.meta:
if isinstance(node.meta['activation_checkpoint'], int):
act_ckpt_label = node.meta['activation_checkpoint']
if "activation_checkpoint" in node.meta:
if isinstance(node.meta["activation_checkpoint"], int):
act_ckpt_label = node.meta["activation_checkpoint"]
else:
act_ckpt_label = node.meta['activation_checkpoint'][check_idx]
act_ckpt_label = node.meta["activation_checkpoint"][check_idx]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
......@@ -298,13 +299,9 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
return ckpt_regions
def emit_ckpt_func(body,
ckpt_func,
node_list: List[Node],
emit_node_func,
delete_unused_value_func,
level=0,
in_ckpt=False):
def emit_ckpt_func(
body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, level=0, in_ckpt=False
):
"""Emit ckpt function in nested way
Args:
body: forward code, in recursive calls, this part will be checkpoint
......@@ -321,17 +318,17 @@ def emit_ckpt_func(body,
inputs, outputs = _find_input_and_output_nodes(node_list)
# if the current checkpoint function use int as label, using old generation method
if isinstance(node_list[0].meta['activation_checkpoint'], int):
label = node_list[0].meta['activation_checkpoint']
if isinstance(node_list[0].meta["activation_checkpoint"], int):
label = node_list[0].meta["activation_checkpoint"]
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
ckpt_func.append(f"{ckpt_fn_def}\n")
for node in node_list:
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = node_list[0].meta.get('activation_offload', False)
ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
activation_offload = node_list[0].meta.get("activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
usage += "\n"
body.append(usage)
......@@ -340,12 +337,12 @@ def emit_ckpt_func(body,
else:
# label given by each layer, e.g. if you are currently at level [0, 1, 1]
# the label will be '0_1_1'
label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]])
label = "_".join([str(idx) for idx in node_list[0].meta["activation_checkpoint"][: level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
ckpt_func.append(f"{ckpt_fn_def}\n")
# if there is more level to fetch
if level + 1 < len(node_list[0].meta['activation_checkpoint']):
if level + 1 < len(node_list[0].meta["activation_checkpoint"]):
ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
......@@ -358,38 +355,45 @@ def emit_ckpt_func(body,
break
if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func,
delete_unused_value_func, level + 1, True)
ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(
ckpt_func,
ckpt_func_buffer,
ckpt_node_list,
emit_node_func,
delete_unused_value_func,
level + 1,
True,
)
node_idx += len(ckpt_node_list)
else:
node = node_list[node_idx]
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
node_idx += 1
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
ckpt_func += ckpt_func_buffer
activation_offload = node_list[0].meta.get('activation_offload', False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
activation_offload = node_list[0].meta.get("activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n"
if in_ckpt:
usage = ' ' + usage
usage = " " + usage
body.append(usage)
# last level
else:
for node in node_list:
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = node_list[0].meta.get('activation_offload', False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
activation_offload = node_list[0].meta.get("activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n"
if in_ckpt:
usage = ' ' + usage
usage = " " + usage
body.append(usage)
......@@ -420,7 +424,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions):
offload_node_list = node_list[start:end + 1]
offload_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs)
offload_outputs.append(outputs)
......@@ -436,7 +440,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# process ckpt_regions
if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list)
......@@ -470,7 +474,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
if within_offload_region:
emit_node_func(node, body)
body[-1] = ' ' + body[-1]
body[-1] = " " + body[-1]
delete_unused_value_func(node, body)
else:
......@@ -508,14 +512,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# find the input and output var names for each region
for idx, (start, end) in enumerate(ckpt_regions):
ckpt_node_list = node_list[start:end + 1]
ckpt_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(ckpt_node_list)
input_vars.append(inputs)
output_vars.append(outputs)
# find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions):
offload_node_list = node_list[start:end + 1]
offload_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs)
offload_outputs.append(outputs)
......@@ -523,11 +527,11 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# append code text to body
for idx, node in enumerate(node_list):
# if this is the first node of the ckpt region
# append the ckpt function defition
# append the ckpt function definition
if idx in start_idx:
label = start_idx.index(idx)
ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
ckpt_func.append(f'{ckpt_fn_def}\n')
ckpt_func.append(f"{ckpt_fn_def}\n")
within_ckpt_region = True
if idx in offload_starts:
......@@ -559,12 +563,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# NOTE: currently we separate body and ckpt_func definition
if within_ckpt_region:
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
elif within_offload_region:
emit_node_func(node, body)
body[-1] = ' ' + body[-1]
body[-1] = " " + body[-1]
delete_unused_value_func(node, body)
else:
......@@ -576,13 +580,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# generate return statement
label = end_idx.index(idx)
return_statement = _gen_ckpt_output(output_vars[label])
return_statement = f' {return_statement}\n\n'
return_statement = f" {return_statement}\n\n"
ckpt_func.append(return_statement)
# we need to check if the checkpoint need to offload the input
start_node_idx = start_idx[label]
if 'activation_offload' in node_list[start_node_idx].meta:
activation_offload = node_list[start_node_idx].meta['activation_offload']
if "activation_offload" in node_list[start_node_idx].meta:
activation_offload = node_list[start_node_idx].meta["activation_offload"]
else:
activation_offload = False
......@@ -594,8 +598,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if input_node.op != "placeholder":
non_leaf_input = 1
for user in input_node.users:
if 'activation_checkpoint' in user.meta:
if user.meta['activation_checkpoint'] == label:
if "activation_checkpoint" in user.meta:
if user.meta["activation_checkpoint"] == label:
if user.op == "call_module":
if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
......@@ -610,7 +614,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# generate checkpoint function call in a new line
usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
usage += '\n'
usage += "\n"
body.append(usage)
within_ckpt_region = False
......@@ -621,7 +625,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if CODEGEN_AVAILABLE:
class ActivationCheckpointCodeGen(CodeGen):
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
......@@ -629,7 +632,7 @@ if CODEGEN_AVAILABLE:
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
maybe_return_annotation: List[str] = ['']
maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
......@@ -662,16 +665,16 @@ if CODEGEN_AVAILABLE:
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
return '()'
return "()"
typename = _type_repr(o)
if hasattr(o, '__origin__'):
if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
if hasattr(o, '__args__'):
if hasattr(o, "__args__"):
# Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__]
......@@ -690,19 +693,18 @@ if CODEGEN_AVAILABLE:
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, '_fields'):
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
args_s = ', '.join(_get_repr(a) for a in args)
kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
args_s = ", ".join(_get_repr(a) for a in args)
kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
if args_s and kwargs_s:
return f'{args_s}, {kwargs_s}'
return f"{args_s}, {kwargs_s}"
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use
......@@ -728,90 +730,101 @@ if CODEGEN_AVAILABLE:
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
if user.op == 'placeholder':
if user.op == "placeholder":
return
if user.op == 'output':
body.append('\n')
if user.op == "output":
body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
body.append(f'; {to_delete_str}\n')
to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
body.append(f"; {to_delete_str}\n")
else:
body.append('\n')
body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
    """
    Append the source text for a single FX node to ``body``.

    Placeholder nodes are also recorded in the enclosing ``free_vars`` list,
    wrapped call_function targets in ``wrapped_fns``, and an output node's
    type sets ``maybe_return_annotation[0]``.

    Args:
        node: the node to emit code for.
        body: list of code strings to append to (the forward body or a
            checkpoint function body).

    Raises:
        NotImplementedError: if ``node.op`` is not one of the handled ops.
    """
    maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
    if node.op == "placeholder":
        assert isinstance(node.target, str)
        maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
        free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
        # Strip any '*' from the target before comparing with the node's name.
        raw_name = node.target.replace("*", "")
        if raw_name != repr(node):
            body.append(f"{repr(node)} = {raw_name}\n")
        return
    elif node.op == "call_method":
        assert isinstance(node.target, str)
        body.append(
            f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
            f"({_format_args(node.args[1:], node.kwargs)})"
        )
        return
    elif node.op == "call_function":
        assert callable(node.target)
        # pretty print operators
        if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
            assert isinstance(node.args, tuple)
            body.append(
                f"{repr(node)}{maybe_type_annotation} = "
                f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
            )
            return

        # pretty print inplace operators; required for jit.script to work properly
        # not currently supported in normal FX graphs, but generated by torchdynamo
        if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
            body.append(
                f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
                f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
            )
            return

        qualified_name = _get_qualified_name(node.target)
        global_name = add_global(qualified_name, node.target)
        # special case for getattr: node.args could be 2-argument or 3-argument
        # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
        if (
            global_name == "getattr"
            and isinstance(node.args, tuple)
            and isinstance(node.args[1], str)
            and node.args[1].isidentifier()
            and len(node.args) == 2
        ):
            body.append(
                f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
            )
            return
        body.append(
            f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
        )
        if node.meta.get("is_wrapped", False):
            wrapped_fns.setdefault(global_name)
        return
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        body.append(
            f"{repr(node)}{maybe_type_annotation} = "
            f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
        )
        return
    elif node.op == "get_attr":
        assert isinstance(node.target, str)
        body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
        return
    elif node.op == "output":
        if node.type is not None:
            maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
        body.append(self.generate_output(node.args[0]))
        return
    raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes):
if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
......@@ -820,13 +833,13 @@ if CODEGEN_AVAILABLE:
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
body.append('pass\n')
body.append("pass\n")
if len(wrapped_fns) > 0:
wrap_name = add_global('wrap', torch.fx.wrap)
wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
wrap_name = add_global("wrap", torch.fx.wrap)
wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
wrap_stmts = ''
wrap_stmts = ""
if self._body_transformer:
body = self._body_transformer(body)
......@@ -837,11 +850,11 @@ if CODEGEN_AVAILABLE:
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = ''.join(ckpt_func) + prologue
prologue = "".join(ckpt_func) + prologue
prologue = prologue
code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n'))
code = "".join(body)
code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{prologue}
......@@ -861,7 +874,7 @@ else:
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
maybe_return_annotation: List[str] = ['']
maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
......@@ -894,12 +907,12 @@ else:
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
return '()'
return "()"
typename = _type_repr(o)
# This is a generic type, e.g. typing.List[torch.Tensor]
if hasattr(o, '__origin__'):
if hasattr(o, "__origin__"):
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
......@@ -934,84 +947,94 @@ else:
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
if user.op == 'placeholder':
if user.op == "placeholder":
return
if user.op == 'output':
body.append('\n')
if user.op == "output":
body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
body.append(f'; {to_delete_str}\n')
to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
body.append(f"; {to_delete_str}\n")
else:
body.append('\n')
body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
    """Append the generated Python source for a single FX node to ``body``.

    The emitted text depends on ``node.op``:

    * ``placeholder``: records the parameter (with optional type annotation
      and default value) in ``free_vars``; if the parameter name with any
      ``*`` stripped differs from ``repr(node)``, an alias assignment line
      is appended to ``body``.
    * ``call_method`` / ``call_function`` / ``call_module`` / ``get_attr``:
      appends one assignment statement of the form ``<node> = <expr>``.
    * ``output``: stores the return annotation in
      ``maybe_return_annotation[0]`` (when the node is typed) and appends
      the ``return`` statement.

    Except for the placeholder alias line, no trailing newline is appended
    here. NOTE(review): presumably the caller terminates the line via
    ``delete_unused_values`` -- confirm against the checkpoint emitters.

    Args:
        node: the graph node to generate code for.
        body: list of source fragments to append to. It is passed explicitly
            so the same emitter can target either the forward body or a
            checkpoint function.

    Raises:
        NotImplementedError: if ``node.op`` is not one of the handled ops.
    """
    maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"

    if node.op == "placeholder":
        assert isinstance(node.target, str)
        maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
        free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
        # Strip the ``*`` / ``**`` markers of variadic parameters before
        # comparing with the node's own name.
        raw_name = node.target.replace("*", "")
        if raw_name != repr(node):
            body.append(f"{repr(node)} = {raw_name}\n")
        return
    elif node.op == "call_method":
        assert isinstance(node.target, str)
        # args[0] is the receiver; the remaining args are the call arguments.
        body.append(
            f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
            f"({_format_args(node.args[1:], node.kwargs)})"
        )
        return
    elif node.op == "call_function":
        assert callable(node.target)
        # pretty print operators
        if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
            assert isinstance(node.args, tuple)
            body.append(
                f"{repr(node)}{maybe_type_annotation} = "
                f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
            )
            return
        qualified_name = _get_qualified_name(node.target)
        global_name = add_global(qualified_name, node.target)
        # special case for getattr: node.args could be 2-argument or 3-argument
        # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
        # The length is checked before ``node.args[1]`` is indexed so that an
        # unexpectedly short argument tuple falls through instead of raising
        # IndexError.
        if (
            global_name == "getattr"
            and isinstance(node.args, tuple)
            and len(node.args) == 2
            and isinstance(node.args[1], str)
            and node.args[1].isidentifier()
        ):
            body.append(
                f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
            )
            return
        body.append(
            f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
        )
        if node.meta.get("is_wrapped", False):
            # Remember the wrapped function's global name (dict used as an
            # ordered set) for the ``wrap(...)`` statements built later.
            wrapped_fns.setdefault(global_name)
        return
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        body.append(
            f"{repr(node)}{maybe_type_annotation} = "
            f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
        )
        return
    elif node.op == "get_attr":
        assert isinstance(node.target, str)
        body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
        return
    elif node.op == "output":
        if node.type is not None:
            # Written through the one-element list so the enclosing scope
            # sees the annotation.
            maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
        if self._pytree_info is None:
            body.append(f"return {repr(node.args[0])}")
        else:
            body.append(f"return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)")
        return
    raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes):
if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in self.nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
......@@ -1020,33 +1043,34 @@ else:
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
body.append('pass\n')
body.append("pass\n")
if self._pytree_info is not None:
orig_args = self._pytree_info.orig_args
has_orig_self = (orig_args[0] == 'self')
has_orig_self = orig_args[0] == "self"
if has_orig_self:
free_vars.insert(0, 'self')
free_vars.insert(0, "self")
if len(free_vars) > 0: # pytree has placeholders in it
body.insert(
0,
f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n")
f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n",
)
else:
orig_args = free_vars
if len(wrapped_fns) > 0:
wrap_name = add_global('wrap', torch.fx.wrap)
wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
wrap_name = add_global("wrap", torch.fx.wrap)
wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
wrap_stmts = ''
wrap_stmts = ""
ckpt_func = ''.join(ckpt_func)
ckpt_func = "".join(ckpt_func)
# If the original function didn't have self as its first argument, we
# would have added it.
if len(orig_args) == 0 or orig_args[0] != 'self':
orig_args.insert(0, 'self')
code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n'))
if len(orig_args) == 0 or orig_args[0] != "self":
orig_args.insert(0, "self")
code = "".join(body)
code = "\n".join(" " + line for line in code.split("\n"))
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment