Commit 7bc5a8e3 authored by zhuwenwen (parents: e6748d82 0f785cb1)
from ._helper import (
add_seed,
get_current_mode,
get_seeds,
get_states,
moe_set_seed,
reset_seeds,
seed,
set_mode,
set_seed_states,
sync_states,
with_seed,
)
__all__ = [
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states',
'sync_states', 'moe_set_seed', 'reset_seeds'
]
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import functools
from contextlib import contextmanager
import torch.cuda
from torch import Tensor
from .seed_manager import SeedManager
from ..parallel_mode import ParallelMode
_SEED_MANAGER = SeedManager()
def get_seeds():
"""Returns the seeds of the seed manager.
Returns:
dict: The seeds of the seed manager.
"""
return _SEED_MANAGER.seeds
def get_states(copy=False):
"""Returns the seed states of the seed manager.
Returns:
dict: The seed states of the seed manager.
"""
states = _SEED_MANAGER.seed_states
if copy:
new_states = dict()
for parallel_mode, state in states.items():
new_states[parallel_mode] = state.clone()
return new_states
else:
return _SEED_MANAGER.seed_states
def get_current_mode():
"""Returns the current mode of the seed manager.
Returns:
:class:`colossalai.context.ParallelMode`: The current mode of the seed manager.
"""
return _SEED_MANAGER.current_mode
def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
"""Adds a seed to the seed manager for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
seed (int): The seed to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
_SEED_MANAGER.add_seed(parallel_mode, seed, overwrite)
def set_mode(parallel_mode: ParallelMode):
"""Sets the current mode of the seed manager.
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
_SEED_MANAGER.set_mode(parallel_mode)
def set_seed_states(parallel_mode: ParallelMode, state: Tensor):
"""Sets the state of the seed manager for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
state (:class:`torch.Tensor`): the state to be set.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager.
"""
_SEED_MANAGER.set_state(parallel_mode, state)
def sync_states():
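"""Writes the current CUDA RNG state of this process back into the seed manager for the current mode."""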
current_mode = get_current_mode()
current_states = torch.cuda.get_rng_state()
set_seed_states(current_mode, current_states)
@contextmanager
def seed(parallel_mode: ParallelMode):
""" A context for seed switch
Examples:
>>> with seed(ParallelMode.DATA):
>>> output = F.dropout(input)
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
try:
# set to new mode
current_mode = _SEED_MANAGER.current_mode
yield _SEED_MANAGER.set_mode(parallel_mode)
finally:
# recover
_SEED_MANAGER.set_mode(current_mode)
def with_seed(func, parallel_mode: ParallelMode):
"""
A function wrapper which executes the function with a specified seed.
Examples:
>>> # use with decorator
>>> @with_seed(ParallelMode.DATA)
>>> def forward(input):
>>> return F.dropout(input)
>>> out = forward(input)
>>> # OR use it inline
>>> def forward(input):
>>> return F.dropout(input)
>>> wrapped_forward = with_seed(forward, ParallelMode.DATA)
>>> out = wrapped_forward(input)
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
# switch mode
current_mode = _SEED_MANAGER.current_mode
_SEED_MANAGER.set_mode(parallel_mode)
# exec func
out = func(*args, **kwargs)
# recover state
_SEED_MANAGER.set_mode(current_mode)
return out
return wrapper
def moe_set_seed(seed):
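"""Registers a rank-specific tensor-parallel seed (seed + global rank) for MoE when CUDA is available."""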
if torch.cuda.is_available():
from colossalai.core import global_context as gpc
global_rank = gpc.get_global_rank()
diff_seed = seed + global_rank
add_seed(ParallelMode.TENSOR, diff_seed, True)
print(f"moe seed condition: {global_rank} with tensor seed {diff_seed}", flush=True)
def reset_seeds():
_SEED_MANAGER.reset()
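# A minimal usage sketch of the helpers above (assumes CUDA and an initialized parallel context;
# the seed values are arbitrary, for illustration only):
# >>> add_seed(ParallelMode.DATA, 1024)
# >>> add_seed(ParallelMode.TENSOR, 1025)
# >>> set_mode(ParallelMode.DATA)
# >>> with seed(ParallelMode.TENSOR):
# >>>     mask = torch.rand(4, device='cuda')    # drawn from the TENSOR RNG stream
# >>> out = torch.rand(4, device='cuda')         # back on the DATA RNG stream
# >>> sync_states()                              # write the live CUDA RNG state back for DATA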
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
class SeedManager:
"""This class is a manager of all random seeds involved in the system.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
def __init__(self):
self._current_mode = None
self._seeds = dict()
self._seed_states = dict()
@property
def current_mode(self):
return self._current_mode
@property
def seeds(self):
return self._seeds
@property
def seed_states(self):
return self._seed_states
def set_state(self, parallel_mode: ParallelMode, state: Tensor):
"""Sets the state of the seed manager for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
state (:class:`torch.Tensor`): the state to be set.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager.
"""
assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager'
self._seed_states[parallel_mode] = state
def set_mode(self, parallel_mode: ParallelMode):
"""Sets the current mode of the seed manager.
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
"""
if self.current_mode:
# save the current state for current mode
self._seed_states[self._current_mode] = torch.cuda.get_rng_state()
# set the new state for new mode
self._current_mode = parallel_mode
torch.cuda.set_rng_state(self._seed_states[parallel_mode])
def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
"""Adds a seed to the seed manager for `parallel_mode`.
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
seed (int): The seed to be added.
overwrite (bool, optional): Whether to allow overwriting a seed that has already been set.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.context.ParallelMode`
or the seed for `parallel_mode` has been added.
"""
assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
if overwrite is False:
assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
elif parallel_mode in self._seed_states:
print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True)
current_state = torch.cuda.get_rng_state()
torch.cuda.manual_seed(seed)
self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
self._seeds[parallel_mode] = seed
torch.cuda.set_rng_state(current_state)
def reset(self):
self._current_mode = None
self._seeds = dict()
self._seed_states = dict()
class SingletonMeta(type):
"""
The Singleton class can be implemented in different ways in Python. Some
possible methods include: base class, decorator, metaclass. We will use the
metaclass because it is best suited for this purpose.
"""
_instances = {}
def __call__(cls, *args, **kwargs):
"""
Possible changes to the value of the `__init__` argument do not affect
the returned instance.
"""
if cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
else:
assert len(args) == 0 and len(kwargs) == 0, \
f'{cls.__name__} is a singleton class and an instance has been created.'
return cls._instances[cls]
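# A small sketch of the intended usage (the class name here is hypothetical):
# >>> class GlobalConfig(metaclass=SingletonMeta):
# >>>     def __init__(self, name='default'):
# >>>         self.name = name
# >>> a = GlobalConfig('first')
# >>> b = GlobalConfig()    # returns the same instance; passing new args now would trip the assert
# >>> assert a is b and b.name == 'first'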
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from colossalai.context.parallel_context import global_context
__all__ = ['global_context']
from .alpha_beta_profiler import AlphaBetaProfiler
from .calc_pipeline_strategy import alpa_dp
__all__ = ['AlphaBetaProfiler', 'alpa_dp']
import math
import time
from typing import Dict, List, Tuple
import torch
import torch.distributed as dist
from colossalai.logging import get_dist_logger
GB = int((1 << 30))
BYTE = 4
FRAMEWORK_LATENCY = 0
class AlphaBetaProfiler:
'''
Profile alpha and beta value for a given device list.
Usage:
# Note: the environment of execution is supposed to be
# multi-process with multi-gpu in mpi style.
>>> physical_devices = [0, 1, 4, 5]
>>> ab_profiler = AlphaBetaProfiler(physical_devices)
>>> ab_dict = ab_profiler.alpha_beta_dict
>>> print(ab_dict)
{(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11),
(1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
(1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
(4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
'''
def __init__(self,
physical_devices: List[int],
alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
ctype: str = 'a',
warmup: int = 5,
repeat: int = 25,
latency_iters: int = 5,
homogeneous_tolerance: float = 0.1):
'''
Args:
physical_devices: A list of device id, each element inside it is the global rank of that device.
alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
ctype: 'a' for all-reduce, 'b' for broadcast.
warmup: Number of warmup iterations.
repeat: Number of iterations to measure.
latency_iters: Number of iterations to measure latency.
homogeneous_tolerance: Relative tolerance used to decide whether two links have a similar (homogeneous) beta value.
'''
self.physical_devices = physical_devices
self.ctype = ctype
self.world_size = len(physical_devices)
self.warmup = warmup
self.repeat = repeat
self.latency_iters = latency_iters
self.homogeneous_tolerance = homogeneous_tolerance
self.process_group_dict = None
self._init_profiling()
if alpha_beta_dict is None:
self.alpha_beta_dict = self.profile_ab()
else:
self.alpha_beta_dict = alpha_beta_dict
def _init_profiling(self):
# Create process group list based on its global rank
process_group_list = []
for f_index in range(self.world_size - 1):
for b_index in range(f_index + 1, self.world_size):
process_group_list.append((self.physical_devices[f_index], self.physical_devices[b_index]))
# Create process group dict which maps process group to its handler
process_group_dict = {}
for process_group in process_group_list:
pg_handler = dist.new_group(process_group)
process_group_dict[process_group] = pg_handler
self.process_group_dict = process_group_dict
def _profile(self, process_group, pg_handler, nbytes):
logger = get_dist_logger()
rank = dist.get_rank()
src_device_num = process_group[0]
world_size = len(process_group)
device = torch.cuda.current_device()
buf = torch.randn(nbytes // 4).to(device)
torch.cuda.synchronize()
# warmup
for _ in range(self.warmup):
if self.ctype == "a":
dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
elif self.ctype == "b":
dist.broadcast(buf, src=src_device_num, group=pg_handler)
torch.cuda.synchronize()
dist.barrier(group=pg_handler)
begin = time.perf_counter()
for _ in range(self.repeat):
if self.ctype == "a":
dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
elif self.ctype == "b":
dist.broadcast(buf, src=src_device_num, group=pg_handler)
torch.cuda.synchronize()
end = time.perf_counter()
dist.barrier(group=pg_handler)
if rank == src_device_num:
avg_time_s = (end - begin) / self.repeat - FRAMEWORK_LATENCY
alg_band = nbytes / avg_time_s
if self.ctype == "a":
# convert the bandwidth of all-reduce algorithm to the bandwidth of the hardware.
bus_band = 2 * (world_size - 1) / world_size * alg_band
bus_band = alg_band
elif self.ctype == "b":
bus_band = alg_band
logger.info(
f"GPU:{rank}, Bytes: {nbytes} B,Time: {round(avg_time_s * 1e6,2)} us, Bus bandwidth: {round(bus_band / GB,2)} GB/s"
)
return (avg_time_s, alg_band)
else:
# Just a placeholder
return (None, None)
def profile_latency(self, process_group, pg_handler):
'''
This function is used to profile the latency of the given process group over a series of message sizes.
Args:
process_group: A tuple of the global ranks in the process group.
pg_handler: The handler of the process group.
Returns:
latency: None if the latency is not measured, otherwise the median of the latency_list.
'''
latency_list = []
for i in range(self.latency_iters):
nbytes = int(BYTE << i)
(t, _) = self._profile(process_group, pg_handler, nbytes)
latency_list.append(t)
if latency_list[0] is None:
latency = None
else:
median_index = math.floor(self.latency_iters / 2)
latency = latency_list[median_index]
return latency
def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
'''
This function is used to profile the bandwidth of the given process group.
Args:
process_group: A tuple of the global ranks in the process group.
pg_handler: The handler of the process group.
maxbytes: The message size (in bytes) used for the bandwidth measurement.
'''
(_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
return bandwidth
def profile_ab(self):
'''
This method profiles the alpha and beta values for a given device list.
Returns:
alpha_beta_dict: A dict which maps process group to its alpha and beta value.
'''
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
rank = dist.get_rank()
global_pg_handler = dist.new_group(self.physical_devices)
def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
assert rank in process_group
device = torch.cuda.current_device()
rank_max_nbytes = torch.cuda.mem_get_info(device)[0]
rank_max_nbytes = torch.tensor(rank_max_nbytes, device=device)
dist.all_reduce(rank_max_nbytes, op=dist.ReduceOp.MIN, group=pg_handler)
max_nbytes = min(int(1 * GB), int(GB << int(math.log2(rank_max_nbytes.item() / GB))))
return max_nbytes
for process_group, pg_handler in self.process_group_dict.items():
if rank not in process_group:
max_nbytes = None
alpha = None
bandwidth = None
else:
max_nbytes = get_max_nbytes(process_group, pg_handler)
alpha = self.profile_latency(process_group, pg_handler)
bandwidth = self.profile_bandwidth(process_group, pg_handler, maxbytes=max_nbytes)
if bandwidth is None:
beta = None
else:
beta = 1 / bandwidth
broadcast_list = [alpha, beta]
dist.broadcast_object_list(broadcast_list, src=process_group[0])
alpha_beta_dict[process_group] = tuple(broadcast_list)
# add symmetric pairs to the alpha_beta_dict
symmetry_ab_dict = {}
for process_group, alpha_beta_pair in alpha_beta_dict.items():
symmetry_process_group = (process_group[1], process_group[0])
symmetry_ab_dict[symmetry_process_group] = alpha_beta_pair
alpha_beta_dict.update(symmetry_ab_dict)
return alpha_beta_dict
def search_best_logical_mesh(self):
'''
This method searches for the best logical mesh for the given device list.
The best logical mesh is searched in the following steps:
1. Detect homogeneous device groups; we assume that the devices in the alpha_beta_dict
are homogeneous if their beta values are close enough.
2. Find the best homogeneous device group that contains all the physical devices. The best
homogeneous group is the one with the lowest beta value among the groups that contain all the
physical devices. We require the group to contain all the physical devices because devices
outside the group would decrease the bandwidth of the group.
3. If such a group is found, we construct the largest ring for each device based on it, and
the best logical mesh is the union of all the rings. Otherwise, the best logical mesh is the
balanced logical mesh, such as shape (2, 2) for 4 devices.
Returns:
best_logical_mesh: The best logical mesh for the given device list.
Usage:
>>> physical_devices = [0, 1, 2, 3]
>>> ab_profiler = AlphaBetaProfiler(physical_devices)
>>> best_logical_mesh = ab_profiler.search_best_logical_mesh()
>>> print(best_logical_mesh)
[[0, 1], [2, 3]]
'''
def _power_of_two(integer):
return integer & (integer - 1) == 0
def _detect_homogeneous_device(alpha_beta_dict):
'''
This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.
Note: we assume that the devices in the alpha_beta_dict are homogeneous if their beta values
are in the range [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
* base_beta.
'''
homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
for process_group, (_, beta) in alpha_beta_dict.items():
if not homogeneous_device_dict:
homogeneous_device_dict[beta] = []
homogeneous_device_dict[beta].append(process_group)
continue
match_beta = None
for beta_value in homogeneous_device_dict.keys():
if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
1 - self.homogeneous_tolerance):
match_beta = beta_value
break
if match_beta is not None:
homogeneous_device_dict[match_beta].append(process_group)
else:
homogeneous_device_dict[beta] = []
homogeneous_device_dict[beta].append(process_group)
return homogeneous_device_dict
def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
'''
This function is used to check whether the homogeneous_group contains all physical devices.
'''
flatten_mesh = []
for process_group in homogeneous_group:
flatten_mesh.extend(process_group)
non_duplicated_flatten_mesh = set(flatten_mesh)
return len(non_duplicated_flatten_mesh) == len(self.physical_devices)
def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
'''
This function is used to construct the largest ring in the homogeneous_group for each rank.
'''
# Construct the ring
ring = []
ranks_in_ring = []
for rank in self.physical_devices:
if rank in ranks_in_ring:
continue
stable_status = False
ring_for_rank = []
ring_for_rank.append(rank)
check_rank_list = [rank]
rank_to_check_list = []
while not stable_status:
stable_status = True
check_rank_list.extend(rank_to_check_list)
rank_to_check_list = []
for i in range(len(check_rank_list)):
check_rank = check_rank_list.pop()
for process_group in homogeneous_group:
if check_rank in process_group:
rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
if rank_to_append not in ring_for_rank:
stable_status = False
rank_to_check_list.append(rank_to_append)
ring_for_rank.append(rank_to_append)
ring.append(ring_for_rank)
ranks_in_ring.extend(ring_for_rank)
return ring
assert _power_of_two(self.world_size)
power_of_two = int(math.log2(self.world_size))
median = power_of_two // 2
balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))
row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
balanced_logical_mesh = []
for row_index in range(row_size):
balanced_logical_mesh.append([])
for column_index in range(column_size):
balanced_logical_mesh[row_index].append(self.physical_devices[row_index * column_size + column_index])
homogeneous_device_dict = _detect_homogeneous_device(self.alpha_beta_dict)
beta_list = [b for b in homogeneous_device_dict.keys()]
beta_list.sort()
beta_list.reverse()
homogeneous_types = len(beta_list)
best_logical_mesh = None
if homogeneous_types >= 2:
for _ in range(homogeneous_types - 1):
lowest_beta = beta_list.pop()
best_homogeneous_group = homogeneous_device_dict[lowest_beta]
# if the best homogeneous group contains all physical devices,
# we will build the logical device mesh based on it. Otherwise,
# we will check next level homogeneous group.
if _check_contain_all_devices(best_homogeneous_group):
# We choose the largest ring for each rank to maximize bus utilization.
best_logical_mesh = _construct_largest_ring(best_homogeneous_group)
break
if homogeneous_types == 1 or best_logical_mesh is None:
# in this case, we use balanced logical mesh as the best
# logical mesh.
best_logical_mesh = balanced_logical_mesh
return best_logical_mesh
def extract_alpha_beta_for_device_mesh(self):
'''
Extract the mesh_alpha list and mesh_beta list based on the
best logical mesh, which will be used to initialize the device mesh.
Usage:
>>> physical_devices = [0, 1, 2, 3]
>>> ab_profiler = AlphaBetaProfiler(physical_devices)
>>> mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
>>> print(mesh_alpha)
[2.5917552411556242e-05, 0.00010312341153621673]
>>> print(mesh_beta)
[5.875573704655635e-11, 4.7361584445959614e-12]
'''
best_logical_mesh = self.search_best_logical_mesh()
first_axis = [row[0] for row in best_logical_mesh]
second_axis = best_logical_mesh[0]
# init process group for both axes
first_axis_process_group = dist.new_group(first_axis)
second_axis_process_group = dist.new_group(second_axis)
# extract alpha and beta for both axes
def _extract_alpha_beta(pg, pg_handler):
latency = self.profile_latency(pg, pg_handler)
bandwidth = self.profile_bandwidth(pg, pg_handler)
broadcast_object = [latency, bandwidth]
dist.broadcast_object_list(broadcast_object, src=pg[0])
return broadcast_object
first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
mesh_alpha = [first_latency, second_latency]
# The beta values have been enlarged by 1e10 times temporarily because the computation cost
# is still estimated in the unit of TFLOPs instead of time. We will remove this factor in the future.
mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth]
return mesh_alpha, mesh_beta
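# Under the alpha-beta model assumed by this profiler, the time to transfer a message of
# n bytes between two devices is approximated as t(n) ~= alpha + beta * n, where alpha is the
# latency measured by profile_latency and beta = 1 / bandwidth measured by profile_bandwidth.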
import numpy as np
def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
submesh_choices = []
i = 1
p = -1
while i <= num_devices_per_host:
i *= 2
p += 1
assert 2**p == num_devices_per_host, ("Only supports the cases where num_devices_per_host is a power of two, "
f"while now num_devices_per_host = {num_devices_per_host}")
if mode == "alpa":
for i in range(p + 1):
submesh_choices.append((1, 2**i))
for i in range(2, num_hosts + 1):
submesh_choices.append((i, num_devices_per_host))
elif mode == "new":
for i in range(p // 2 + 1):
for j in range(i, p - i + 1):
submesh_choices.append((2**i, 2**j))
return submesh_choices
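# For example (values shown assume the integer arithmetic above):
# >>> get_submesh_choices(num_hosts=1, num_devices_per_host=4, mode="new")
# [(1, 1), (1, 2), (1, 4), (2, 2)]
# >>> get_submesh_choices(num_hosts=2, num_devices_per_host=4, mode="alpa")
# [(1, 1), (1, 2), (1, 4), (2, 4)]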
def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
best_configs):
"""Implementation of Alpa DP for pipeline strategy
Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
Arguments:
num_layers: K
num_devices: N*M
num_microbatches: B
submesh_choices: List[(n_i,m_i)]
compute_cost: t_intra
"""
# For f, layer ID start from 0
# f[#pipeline stages, layer id that is currently being considered, number of devices used]
f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1), 0.0, dtype=np.float32)
f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3), -1, dtype=np.int32)
f[0, num_layers, 0] = 0
for s in range(1, num_layers + 1):
for k in range(num_layers - 1, -1, -1):
for d in range(1, num_devices + 1):
for m, submesh in enumerate(submesh_choices):
n_submesh_devices = np.prod(np.array(submesh))
if n_submesh_devices <= d:
# TODO: [luzgh]: Why alpa needs max_n_succ_stages? Delete.
# if s - 1 <= max_n_succ_stages[i, k - 1, m, n_config]:
# ...
for i in range(num_layers, k, -1):
stage_cost = compute_cost[k, i, m]
new_cost = f[s - 1, i, d - n_submesh_devices] + stage_cost
if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]):
f[s, k, d] = new_cost
f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
f_argmin[s, k, d] = (i, m, best_configs[k, i, m])
best_s = -1
best_total_cost = np.inf
for s in range(1, num_layers + 1):
if f[s, 0, num_devices] < best_total_cost:
best_s = s
best_total_cost = f[s, 0, num_devices]
if np.isinf(best_total_cost):
return np.inf, None
total_cost = f[best_s, 0, num_devices] + (num_microbatches - 1) * f_stage_max[best_s, 0, num_devices]
current_s = best_s
current_layer = 0
current_devices = num_devices
res = []
while current_s > 0 and current_layer < num_layers and current_devices > 0:
next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
assert next_start_layer != -1 and current_devices != -1
res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
current_s -= 1
current_layer = next_start_layer
current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
assert (current_s == 0 and current_layer == num_layers and current_devices == 0)
return total_cost, res
def alpa_dp(num_layers,
num_devices,
num_microbatches,
submesh_choices,
num_autosharding_configs,
compute_cost,
gap=1e-6):
"""Alpa auto stage dynamic programming.
Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
Arguments:
submesh_choices: List[(int,int)]
num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs)
"""
assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices),
num_autosharding_configs), "Cost shape wrong."
all_possible_stage_costs = np.sort(np.unique(compute_cost))
best_cost = np.inf
best_solution = None
last_max_stage_cost = 0.0
# TODO: [luzgh]: Why alpa needs the num_autosharding_configs dimension in compute_cost?
# In dp_impl it seems the argmin n_config will be chosen. Just amin here.
best_configs = np.argmin(compute_cost, axis=3)
best_compute_cost = np.amin(compute_cost, axis=3)
assert len(all_possible_stage_costs), "no solution in auto stage construction."
for max_stage_cost in all_possible_stage_costs:
if max_stage_cost * num_microbatches >= best_cost:
break
if max_stage_cost - last_max_stage_cost < gap:
continue
cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost,
max_stage_cost, best_configs)
if cost < best_cost:
best_cost = cost
best_solution = solution
last_max_stage_cost = max_stage_cost
return best_cost, best_solution
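# Roughly, the recurrence implemented by alpa_dp_impl (in the Alpa paper's notation) is
#   f[s, k, d] = min over (i, m) of { compute_cost[k, i, m] + f[s - 1, i, d - n_m] },
# subject to compute_cost[k, i, m] <= max_stage_cost, where s is the number of remaining stages,
# k the first layer of the current stage, i the first layer of the next stage, d the number of
# devices still available and n_m the device count of submesh choice m. The reported cost is then
#   total_cost = f[best_s, 0, num_devices] + (num_microbatches - 1) * f_stage_max[best_s, 0, num_devices],
# and alpa_dp sweeps max_stage_cost over all distinct stage costs to minimize it.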
"""This code is adapted from Alpa
https://github.com/alpa-projects/alpa/
with some changes. """
import operator
from copy import deepcopy
from functools import reduce
from typing import List, Tuple
import torch
import torch.distributed as dist
# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)
class DeviceMesh:
"""A logical view of a physical cluster. For example, we could view a physical cluster
with 16 devices as a device mesh with shape (2, 2, 4) or (4, 4).
Arguments:
physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
logical_mesh_id (torch.Tensor): logical view of the devices in global rank.
mesh_shape (torch.Size, optional): shape of logical view.
mesh_alpha (List[float], optional): coefficients used for computing
communication cost (default: None)
mesh_beta (List[float], optional): coefficients used for computing
communication cost (default: None)
init_process_group (bool, optional): initialize the logical process groups
while the DeviceMesh instance is being constructed if init_process_group is set to True.
Otherwise, users need to call create_process_groups_for_logical_mesh manually to initialize the logical process groups.
(default: False)
need_flatten (bool, optional): initialize flatten_device_mesh while the DeviceMesh instance is being constructed if need_flatten is set to True.
"""
def __init__(self,
physical_mesh_id: torch.Tensor,
mesh_shape: torch.Size = None,
logical_mesh_id: torch.Tensor = None,
mesh_alpha: List[float] = None,
mesh_beta: List[float] = None,
init_process_group: bool = False,
need_flatten: bool = True):
self.physical_mesh_id = physical_mesh_id
if logical_mesh_id is None:
self.mesh_shape = mesh_shape
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
else:
self._logical_mesh_id = logical_mesh_id
self.mesh_shape = self._logical_mesh_id.shape
# map global rank into logical rank
self.convert_map = {}
self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
# coefficient for alpha-beta communication model
if mesh_alpha is None:
mesh_alpha = [1] * len(self.mesh_shape)
if mesh_beta is None:
mesh_beta = [1] * len(self.mesh_shape)
self.mesh_alpha = tuple(mesh_alpha)
self.mesh_beta = tuple(mesh_beta)
self.init_process_group = init_process_group
self.need_flatten = need_flatten
if self.init_process_group:
self.process_groups_dict = self.create_process_groups_for_logical_mesh()
if self.need_flatten and self._logical_mesh_id.dim() > 1:
self.flatten_device_mesh = self.flatten()
# Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
# self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
# self.mesh_beta)
@property
def shape(self):
return self.mesh_shape
@property
def num_devices(self):
return reduce(operator.mul, self.physical_mesh_id.shape, 1)
@property
def logical_mesh_id(self):
return self._logical_mesh_id
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k != 'process_groups_dict':
setattr(result, k, __import__("copy").deepcopy(v, memo))
else:
setattr(result, k, v)
return result
def flatten(self):
"""
Flatten the logical mesh into an effective 1d logical mesh.
"""
flatten_mesh_shape_size = len(self.mesh_shape)
flatten_mesh_shape = [self.num_devices]
return DeviceMesh(self.physical_mesh_id,
tuple(flatten_mesh_shape),
mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
init_process_group=self.init_process_group,
need_flatten=False)
def _global_rank_to_logical_rank_map(self, tensor, index_list):
'''
This method is a helper function to build convert_map recursively.
'''
for index, inner_tensor in enumerate(tensor):
if inner_tensor.numel() == 1:
self.convert_map[int(inner_tensor)] = index_list + [index]
else:
self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index])
def create_process_groups_for_logical_mesh(self):
'''
This method is used to initialize the logical process groups which will be used for communication
on the logical device mesh.
Note: if init_process_group is set to False, you have to call this method manually. Otherwise,
communication-related functions, such as ShapeConsistencyManager.apply, will raise errors.
'''
process_groups_dict = {}
check_duplicate_list = []
global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist()
for global_rank in global_rank_flatten_list:
process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank)
for axis, process_group in process_groups.items():
if axis not in process_groups_dict:
process_groups_dict[axis] = []
if process_group not in check_duplicate_list:
check_duplicate_list.append(process_group)
process_group_handler = dist.new_group(process_group)
process_groups_dict[axis].append((process_group, process_group_handler))
return process_groups_dict
def global_rank_to_logical_rank(self, rank):
return self.convert_map[rank]
def global_rank_to_process_groups_with_logical_rank(self, rank):
'''
Given a global rank, return all logical process groups of this rank.
for example:
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
print(device_mesh.global_rank_to_process_groups_with_logical_rank(0))
output:
# key is axis name
# value is a list of logical ranks in same axis with rank 0
{0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]}
'''
process_groups = {}
for d in range(self.logical_mesh_id.dim()):
for replacer in range(self.logical_mesh_id.shape[d]):
if d not in process_groups:
process_groups[d] = []
process_group_member = self.convert_map[rank].copy()
process_group_member[d] = replacer
process_groups[d].append(process_group_member)
return process_groups
def global_rank_to_process_groups_with_global_rank(self, rank):
'''
Given a global rank, return all process groups of this rank.
for example:
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
print(device_mesh.global_rank_to_process_groups_with_global_rank(0))
output:
# key is axis name
# value is a list of global ranks in same axis with rank 0
{0: [0, 4, 8, 12], 1: [0, 1, 2, 3]}
'''
logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank)
process_groups = {}
for dim, logical_ranks in logical_process_groups.items():
process_groups[dim] = []
for logical_rank in logical_ranks:
for g_rank, l_rank in self.convert_map.items():
if l_rank == logical_rank:
process_groups[dim].append(g_rank)
return process_groups
def all_gather_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
0.1)
def all_reduce_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes +
0.01)
def reduce_scatter_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
0.001)
def all_to_all_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
penalty_factor = num_devices / 2.0
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
(num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
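# A minimal sketch of constructing a logical view without touching torch.distributed
# (init_process_group stays False, so no process groups are created):
# >>> physical_mesh_id = torch.arange(0, 4)
# >>> device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))
# >>> device_mesh.logical_mesh_id
# tensor([[0, 1],
#         [2, 3]])
# >>> device_mesh.global_rank_to_logical_rank(3)
# [1, 1]
# >>> device_mesh.global_rank_to_process_groups_with_global_rank(0)
# {0: [0, 2], 1: [0, 1]}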
class FlattenDeviceMesh(DeviceMesh):
def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
super().__init__(physical_mesh_id,
mesh_shape,
mesh_alpha=mesh_alpha,
mesh_beta=mesh_beta,
init_process_group=False,
need_flatten=False)
# Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars
self.mesh_alpha = max(self.mesh_alpha)
self.mesh_beta = min(self.mesh_beta)
# Different from original process_groups_dict, rank_list is not stored
self.process_number_dict = self.create_process_numbers_for_logical_mesh()
def create_process_numbers_for_logical_mesh(self):
'''
Build 1d DeviceMesh in column-major(0) and row-major(1)
for example:
mesh_shape = (2,4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7]]
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
'''
num_devices = reduce(operator.mul, self.mesh_shape, 1)
process_numbers_dict = {}
process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist()
process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist()
return process_numbers_dict
def mix_gather_cost(self, num_bytes):
num_devices = reduce(operator.mul, self.mesh_shape, 1)
return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1)
from ._base_engine import Engine
from .gradient_handler import *
__all__ = ['Engine']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
from typing import Iterable, List, Optional, Type
from torch import Tensor
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from colossalai.engine.gradient_handler import BaseGradientHandler
from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
from colossalai.logging import get_dist_logger
from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
class Engine:
"""Basic engine class for training and evaluation. It runs a specific process method
:meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
It controls an iteration in training.
Args:
model (``torch.nn.Module``): The neural network model.
optimizer (``colossalai.nn.optimizer.ColossalaiOptimizer``): Optimizer for updating the parameters.
criterion (``torch.nn.modules.loss._Loss``, optional): Loss function for calculating loss.
gradient_handlers (List[``BaseGradientHandler``], optional): A list of gradient handler used in backward.
clip_grad_norm (float, optional): The norm of gradient clipping.
ophook_list (list): List of ophook.
verbose (bool): whether to display log info.
schedule (``BaseSchedule``, optional): Runtime schedule. Defaults to None.
Examples:
>>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
>>> model = ...
>>> criterion = ...
>>> optimizer = ...
>>> train_dataloader = ...
>>> engine, _, _, _ = colossalai.initialize(model, optimizer, criterion)
>>> engine.train()
>>> for inputs, labels in train_dataloader:
>>> # set gradients to zero
>>> engine.zero_grad()
>>> # run forward pass
>>> outputs = engine(inputs)
>>> # compute loss value and run backward pass
>>> loss = engine.criterion(outputs, labels)
>>> engine.backward(loss)
>>> # update parameters
>>> engine.step()
Examples of using Engine in training can be found in
`Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_ and
`Run resnet cifar10 with engine <https://github.com/hpcaitech/ColossalAI-Examples/blob/main/image/resnet/run_resnet_cifar10_with_engine.py>`_.
"""
def __init__(self,
model: Module,
optimizer: "ColossalaiOptimizer",
criterion: Optional[_Loss] = None,
gradient_handlers: Optional[List[BaseGradientHandler]] = None,
clip_grad_norm: float = 0.0,
ophook_list: Optional[List[BaseOpHook]] = None,
verbose: bool = True,
schedule: Optional[BaseSchedule] = None):
self._model = model
self._optimizer = optimizer
self._criterion = criterion
self._clip_grad_norm = clip_grad_norm
self._verbose = verbose
self._logger = get_dist_logger()
# state
self.training = True # default
# build gradient handler
if gradient_handlers:
self._gradient_handlers = gradient_handlers
else:
self._gradient_handlers = []
if ophook_list is None:
self._ophook_list = []
else:
self._ophook_list = ophook_list
# build schedule
if schedule:
assert isinstance(schedule, BaseSchedule), \
f'expected schedule to be of type BaseSchedule, but got {type(schedule)}'
self._schedule = schedule
else:
self._schedule = NonPipelineSchedule()
if self.uses_pipeline:
self._schedule.pre_processing(self)
# register hook if any
if len(self._ophook_list) > 0:
register_ophooks_recursively(self._model, self._ophook_list)
@property
def ophooks(self):
"""show current activated ophooks"""
return self._ophook_list
@property
def model(self):
"""Model attached to the engine"""
return self._model
@property
def optimizer(self):
"""Optimizer attached to the engine"""
return self._optimizer
@property
def criterion(self):
"""Criterion attached to the engine"""
return self._criterion
@property
def schedule(self):
"""Schedule attached to the engine"""
return self._schedule
@property
def uses_pipeline(self):
"""show the pipeline parallel used or not"""
return isinstance(self._schedule, (PipelineSchedule, InterleavedPipelineSchedule))
def add_hook(self, ophook: Type[BaseOpHook]) -> None:
"""add necessary hook"""
# check whether this hook already exists
for h in self._ophook_list:
if type(h) == type(ophook):
logger = get_dist_logger()
logger.warning(f"duplicate hooks, at least two instance of {type(ophook)}")
self._ophook_list.append(ophook)
register_ophooks_recursively(self._model, self._ophook_list)
def remove_hook(self, ophook: Type[BaseOpHook]) -> None:
"""remove hook"""
logger = get_dist_logger()
logger.warning(f"removing hooks is currently not supported")
def zero_grad(self):
"""Set the gradient of parameters to zero
"""
self.optimizer.zero_grad()
def step(self):
"""Execute parameter update
"""
self._all_reduce_gradients()
self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
return self.optimizer.step()
def backward(self, loss: Tensor):
"""Start backward propagation given the loss value computed by a loss function.
Args:
loss (:class:`torch.Tensor`): Loss value computed by a loss function.
"""
ret = self.optimizer.backward(loss)
for ophook in self._ophook_list:
ophook.post_iter()
return ret
def backward_by_grad(self, tensor, grad):
"""Start backward propagation given the gradient of the output tensor.
Args:
tensor (:class:`torch.Tensor`): Output tensor.
grad (:class:`torch.Tensor`): Gradient passed back to the output.
"""
ret = self.optimizer.backward_by_grad(tensor, grad)
for ophook in self._ophook_list:
ophook.post_iter()
return ret
def __call__(self, *args, **kwargs):
"""Run the forward step for the model.
Returns:
Tuple[:class:`torch.Tensor`] or :class:`torch.Tensor`: Output of the model.
"""
return self.model(*args, **kwargs)
def _all_reduce_gradients(self):
"""Handles all-reduce operations of gradients across different parallel groups.
"""
for handler in self._gradient_handlers:
handler.handle_gradient()
def execute_schedule(self, data_iter: Iterable, **kwargs):
"""Run the forward, loss computation, and backward for the model.
Returns a tuple of (output, label, loss).
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
"""
output, label, loss = self._schedule.forward_backward_step(self, data_iter, **kwargs)
return output, label, loss
def train(self):
"""Sets the model to training mode.
"""
self.training = True
self._model.train()
def eval(self):
"""Sets the model to evaluation mode.
"""
self.training = False
self._model.eval()
from typing import Iterable, List
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from colossalai.engine import BaseGradientHandler
from ._gradient_accumulation import (
GradAccumDataloader,
GradAccumGradientHandler,
GradAccumLrSchedulerByStep,
GradAccumOptimizer,
)
__all__ = [
'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep',
'GradAccumGradientHandler'
]
def accumulate_gradient(model: nn.Module,
optimizer: Optimizer,
dataloader: Iterable,
accumulate_size: int,
gradient_handlers: List[BaseGradientHandler] = None,
lr_scheduler: _LRScheduler = None):
r"""Turning model, optimizer, dataloader into corresponding object for gradient accumulation.
Args:
model (:class:`torch.nn.Module`): your model object for gradient accumulation.
optimizer (:class:`torch.optim.Optimizer`): your optimizer object for gradient accumulation.
dataloader (:class:`torch.utils.data.DataLoader` or iterable objects):
your dataloader object, which will be called like iter(dataloader)
accumulate_size (int): the number of steps to accumulate gradients
gradient_handlers (List[:class:`colossalai.engine.BaseGradientHandler`]):
list of gradient handler objects. Default is None.
lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`):
your ``lr_scheduler`` object for gradient accumulation. Defaults to None.
More details about `gradient_handlers` could be found in
`Gradient_handler <https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/engine/gradient_handler>`_.
More details about `lr_scheduler` could be found
`lr_scheduler <https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/nn/lr_scheduler>`_. and
`how to adjust learning rate <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
"""
optimizer = GradAccumOptimizer(optimizer, accumulate_size=accumulate_size, model=model)
dataloader = GradAccumDataloader(dataloader, accumulate_size=accumulate_size)
if gradient_handlers is not None:
gradient_handlers = [GradAccumGradientHandler(handler, accumulate_size) for handler in gradient_handlers]
if lr_scheduler is not None:
lr_scheduler = GradAccumLrSchedulerByStep(lr_scheduler, accumulate_size=accumulate_size)
return optimizer, dataloader, gradient_handlers, lr_scheduler
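# A minimal training-loop sketch using the returned wrappers (model, optimizer, criterion and
# train_dataloader are assumed to be defined elsewhere):
# >>> optimizer, dataloader, _, lr_scheduler = accumulate_gradient(
# >>>     model, optimizer, train_dataloader, accumulate_size=4, lr_scheduler=lr_scheduler)
# >>> for inputs, labels in dataloader:
# >>>     optimizer.zero_grad()                      # only clears grads at the start of a cycle
# >>>     loss = criterion(model(inputs), labels)
# >>>     optimizer.backward(loss)                   # loss is scaled by 1 / accumulate_size
# >>>     optimizer.step()                           # the real update happens every 4th call
# >>>     lr_scheduler.step()                        # likewise stepped every 4th call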
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Iterable, Tuple, Union
import torch.nn as nn
from torch import Tensor
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from colossalai.engine import BaseGradientHandler
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import conditional_context
class GradAccumOptimizer(ColossalaiOptimizer):
"""A wrapper for the optimizer to enable gradient accumulation by skipping the steps
before accumulation size is reached.
Args:
optim (:class:`torch.optim.Optimizer`): Your optimizer object for gradient accumulation.
accumulate_size (int): The number of steps to accumulate gradients.
model (:class:`torch.nn.Module`):
Your model object to check if it is DistributedDataParallel for special handling of no_sync() context.
"""
def __init__(self, optim: Optimizer, accumulate_size: int, model: nn.Module = None):
super().__init__(optim)
self.accumulate_size = accumulate_size
self.accumulate_step = 0
# handle pytorch ddp auto all reduce
self.model = model
self.is_torch_ddp = isinstance(self.model, DistributedDataParallel)
def zero_grad(self, *args, **kwargs) -> None:
"""
Set all gradients to zero.
Args:
*args: positional arguments for the optimizer wrapped
**kwargs: keyword arguments for the optimizer wrapped
"""
if self.accumulate_step == 0:
self.optim.zero_grad(*args, **kwargs)
def step(self, *args, **kwargs) -> None:
"""
Update the model parameters.
Args:
*args: positional arguments for the optimizer wrapped
**kwargs: keyword arguments for the optimizer wrapped
"""
if self.accumulate_step < self.accumulate_size:
return None
else:
self.accumulate_step = 0
return self.optim.step(*args, **kwargs)
def clip_grad_norm(self, model: nn.Module, max_norm: float) -> None:
"""
Clip gradients by norm.
Args:
model (:class:`torch.nn.Module`): a torch module instance
max_norm (float): the max norm for gradient clipping
"""
if self.accumulate_step < self.accumulate_size:
pass
else:
self.optim.clip_grad_norm(model, max_norm)
def backward(self, loss: Tensor) -> None:
"""Execute backward pass.
Args:
loss (:class:`torch.Tensor`): the loss value.
"""
self.accumulate_step += 1
if self.is_torch_ddp:
no_sync = self.accumulate_step < self.accumulate_size
with conditional_context(self.model.no_sync(), enable=no_sync):
scaled_loss = loss / self.accumulate_size
self.optim.backward(scaled_loss)
else:
scaled_loss = loss / self.accumulate_size
self.optim.backward(scaled_loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
"""Execute backward pass given the gradients of the output.
Args:
loss (:class:`torch.Tensor`): the loss value.
grad (:class:`torch.Tensor`): the output gradient.
"""
self.accumulate_step += 1
no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
if no_sync:
with self.model.no_sync():
self.optim.backward_by_grad(tensor, grad)
else:
self.optim.backward_by_grad(tensor, grad)
class GradAccumDataloader:
"""A wrapper for dataloader to enable gradient accumulation by dropping the last incomplete steps.
Note:
The dataloader would drop the last incomplete steps for gradient accumulation.
For example, if a dataloader has 10 batches of data and accumulate size is 4. The model parameters will
be updated only twice at step 4 and step 8. The last two batches of data do not form a complete 4-step cycle.
Thus, they will be automatically skipped by this class. If the dataloader is not standard PyTorch dataloader,
(e.g. Dali dataloader), this class will automatically consume (load data for nothing) the remaining 2 batches.
Args:
dataloader (``Iterable``): Your dataloader object for gradient accumulation.
accumulate_size (int): The number of steps to accumulate gradients.
"""
def __init__(self, dataloader: Iterable, accumulate_size: int) -> None:
self.dataloader = dataloader
self.consume_remain_data = not isinstance(dataloader, DataLoader)
self.steps_per_epoch = len(dataloader) - len(dataloader) % accumulate_size
def __getattr__(self, __name: str) -> Any:
return getattr(self.dataloader, __name)
def __len__(self) -> int:
return self.steps_per_epoch
def __iter__(self) -> Iterable:
self._cur_step = 0
self._dataiter = iter(self.dataloader)
return self
def __next__(self) -> Union[Tensor, Tuple[Tensor]]:
if self._cur_step < self.steps_per_epoch:
self._cur_step += 1
data = next(self._dataiter)
if self._cur_step == self.steps_per_epoch and self.consume_remain_data:
# this is to handle non standard pytorch dataloader
# such as dali dataloader
while True:
try:
_ = next(self._dataiter)
except StopIteration:
break
return data
else:
raise StopIteration
class GradAccumLrSchedulerByStep(_LRScheduler):
"""A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps
before accumulation size is reached.
Args:
lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`):
Your ``lr_scheduler`` object for gradient accumulation.
accumulate_size (int): The number of steps to accumulate gradients.
"""
def __init__(self, lr_scheduler: _LRScheduler, accumulate_size: int) -> None:
self.lr_scheduler = lr_scheduler
self.accumulate_size = accumulate_size
self.accumulate_step = 0
@staticmethod
def compute_effective_steps_per_epoch(dataloader: Iterable, accumulate_size: int) -> int:
"""
Computes the number of effective training iterations. An effective iteration is defined
as the aggregation of <accumulate_size> iterations. For example, if accumulate_size = 4,
then 4 iterations are considered as one effective iteration.
Args:
dataloader (``Iterable``): Your dataloader object for gradient accumulation.
accumulate_size (int): The number of steps to accumulate gradients.
"""
return len(dataloader) // accumulate_size
def __getattr__(self, __name: str) -> Any:
return getattr(self.lr_scheduler, __name)
def step(self, *args, **kwargs) -> None:
"""
Update the learning rate.
Args:
*args: positional arguments for the lr scheduler wrapped.
**kwargs: keyword arguments for the lr scheduler wrapped.
"""
self.accumulate_step += 1
if self.accumulate_step < self.accumulate_size:
pass
else:
self.accumulate_step = 0
self.lr_scheduler.step(*args, **kwargs)
def get_lr(self) -> Tensor:
"""
Compute the next learning rate.
Returns:
Tensor: the upcoming learning rate.
"""
return self.lr_scheduler.get_lr()
def get_last_lr(self) -> Tensor:
"""
Returns the current learning rate.
Returns:
Tensor: the current learning rate.
"""
return self.lr_scheduler.get_last_lr()
def print_lr(self, *args, **kwargs) -> None:
"""
Print the learning rate.
Args:
*args: positional arguments for the lr scheduler wrapped.
**kwargs: keyword arguments for the lr scheduler wrapped.
"""
self.lr_scheduler.print_lr(*args, **kwargs)
def state_dict(self) -> dict:
"""
Returns the states of the lr scheduler as dictionary.
Returns:
dict: the states of the lr scheduler.
"""
return self.lr_scheduler.state_dict()
def load_state_dict(self, state_dict: dict) -> None:
"""
Load the states of the lr scheduler from a dictionary object.
Args:
state_dict (dict): the states of the lr scheduler.
"""
self.lr_scheduler.load_state_dict(state_dict)
class GradAccumGradientHandler:
r"""A wrapper for the gradient handler to enable gradient accumulation by skipping the steps
before accumulation size is reached.
Args:
grad_handler (:class:`colossalai.engine.BaseGradientHandler`):
Your ``gradient_handler`` object for gradient accumulation, which will be called once `accumulate_size` is reached.
accumulate_size (int): The number of steps to accumulate gradients.
More details about ``gradient_handlers`` could be found in
`Gradient_handler <https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/engine/gradient_handler>`_.
"""
def __init__(self, grad_handler: BaseGradientHandler, accumulate_size: int) -> None:
assert isinstance(grad_handler, BaseGradientHandler), \
f'expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}'
self.grad_handler = grad_handler
self.accumulate_size = accumulate_size
self.accumulate_step = 0
def handle_gradient(self) -> None:
"""
Handle gradients reduction only in the last gradient accumulation step.
"""
self.accumulate_step += 1
if self.accumulate_step < self.accumulate_size:
pass
else:
self.accumulate_step = 0
self.grad_handler.handle_gradient()
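# For instance, given an existing handler instance, the wrapper only performs the real
# reduction on every `accumulate_size`-th call:
# >>> acc_handler = GradAccumGradientHandler(grad_handler, accumulate_size=4)
# >>> for _ in range(4):
# >>>     acc_handler.handle_gradient()    # the inner handler runs only on the 4th call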
from ._base_gradient_handler import BaseGradientHandler
from ._data_parallel_gradient_handler import DataParallelGradientHandler
from ._moe_gradient_handler import MoeGradientHandler
from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler
__all__ = [
'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
'MoeGradientHandler', 'SequenceParallelGradientHandler'
]
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
class BaseGradientHandler(ABC):
"""A basic helper class to handle all-reduce operations of gradients across different parallel groups
before optimization.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def __init__(self, model, optimizer):
self._model = model
self._optimizer = optimizer
@abstractmethod
def handle_gradient(self):
"""A method to accumulate gradients across different parallel groups. Users should
write their own functions or just use the functions in pre-defined subclasses.
"""
pass
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce
@GRADIENT_HANDLER.register_module
class DataParallelGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in a data parallel group.
An all-reduce collective communication will be performed in
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
the same type to improve the efficiency of communication.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def handle_gradient(self):
"""A method running a all-reduce operation in a data parallel group.
"""
# TODO: add memory buffer
if gpc.data_parallel_size > 1:
bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.DATA))
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from colossalai.utils.moe import get_moe_epsize_param_dict
from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce
@GRADIENT_HANDLER.register_module
class MoeGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in a data parallel group and
moe model parallelism. An all-reduce collective communication will be performed in
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
the same type to improve the efficiency of communication.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def __init__(self, model, optimizer=None):
super().__init__(model, optimizer)
def handle_gradient(self):
"""A method running an all-reduce operation in a data parallel group.
Then running an all-reduce operation for all parameters in experts
across moe model parallel group
"""
global_data = gpc.data_parallel_size
if global_data > 1:
epsize_param_dict = get_moe_epsize_param_dict(self._model)
# epsize is 1, indicating the params are replicated among processes in data parallelism
# use the ParallelMode.DATA to get data parallel group
# reduce gradients for all parameters in data parallelism
if 1 in epsize_param_dict:
bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA))
for ep_size in epsize_param_dict:
if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
bucket_allreduce(param_list=epsize_param_dict[ep_size],
group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
#!/usr/bin/env python
from collections import defaultdict
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
@GRADIENT_HANDLER.register_module
class PipelineSharedModuleGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in sub parallel groups.
An all-reduce collective communication will be performed in
:func:`handle_gradient` among all sub pipeline parallel groups.
For better performance, it bucketizes the gradients of all parameters that are
the same type to improve the efficiency of communication.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def handle_gradient(self):
"""A method running a all-reduce operation in sub pipeline parallel groups.
"""
if gpc.pipeline_parallel_size > 1:
# bucketize and all-reduce
buckets = defaultdict(lambda: defaultdict(list))
# Pack the buckets.
for param in self._model.parameters():
group = getattr(param, 'pipeline_shared_module_pg', None)
if param.requires_grad and group is not None and (
(hasattr(param, 'colo_attr') and not param.colo_attr.saved_grad.is_null())
or param.grad is not None):
tp = param.data.type()
buckets[group][tp].append(param)
# For each bucket, all-reduce and copy all-reduced grads.
for group, group_buckets in buckets.items():
for tp, bucket in group_buckets.items():
grads = [
param.colo_attr.grad_payload if hasattr(param, 'colo_attr') else param.grad.data
for param in bucket
]
coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce
@GRADIENT_HANDLER.register_module
class SequenceParallelGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in a data parallel group.
An all-reduce collective communication will be performed in
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
the same type to improve the efficiency of communication.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def handle_gradient(self):
"""A method running a all-reduce operation in a data parallel group.
"""
if gpc.get_world_size(ParallelMode.SEQUENCE_DP) > 1:
bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.SEQUENCE_DP))
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
@GRADIENT_HANDLER.register_module
class ZeROGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in a data parallel group.
An all-reduce collective communication will be performed in
:func:`handle_gradient` among a data parallel group.
This class is specialized with ZeRO optimization.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def handle_gradient(self):
"""A method running a all-reduce operation in a data parallel group.
"""
self._optimizer.sync_grad()
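# A sketch of how a custom handler could be plugged in, following the same registration
# pattern as the handlers above (the class below is hypothetical, not part of the library):
# >>> @GRADIENT_HANDLER.register_module
# >>> class MyNoopGradientHandler(BaseGradientHandler):
# >>>     def handle_gradient(self):
# >>>         pass    # e.g. skip the reduction entirely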