Commit 5a560a06 authored by Jiarui Fang's avatar Jiarui Fang Committed by Frank Lee
Browse files

Feature/zero (#279)



* add zero1 (#209)

* add zero1

* add test zero1

* update zero stage 1 develop (#212)

* Implement naive zero3 (#240)

* naive zero3 works well

* add zero3 param manager

* add TODOs in comments

* add gather full param ctx

* fix sub module streams

* add offload

* fix bugs of hook and add unit tests

* fix bugs of hook and add unit tests (#252)

* add gather full param ctx

* fix sub module streams

* add offload

* fix bugs of hook and add unit tests

* polish code and add state dict hook

* fix bug

* update unit test

* refactor reconstructed zero code

* clip_grad support zero3 and add unit test

* add unit test for Zero3ParameterManager

* [WIP] initialize the shard param class

* [WIP] Yet another sharded model implementation (#274)

* [WIP] initialize the shard param class

* [WIP] Yes another implementation of shardModel. Using a better hook method.

* torch.concat -> torch.cat

* fix test_zero_level_1.py::test_zero_level_1 unitest

* remove deepspeed implementation and refactor for the reconstructed zero module

* polish zero dp unittests
Co-authored-by: default avatarver217 <lhx0217@gmail.com>
Co-authored-by: default avatarFrank Lee <somerlee.9@gmail.com>
parent 08eccfe6
from .gradient_store import GradientStore
from .parameter_store import ParameterStore
from .bucket_store import BucketStore
from .tensor_bucket import TensorBucket
__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket']
\ No newline at end of file
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
class BaseStore:
def __init__(self, dp_parallel_mode=ParallelMode.DATA):
self._world_size = gpc.get_world_size(dp_parallel_mode)
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
@property
def world_size(self):
return self._world_size
@property
def local_rank(self):
return self._local_rank
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from .base_store import BaseStore
class BucketStore(BaseStore):
def __init__(self, dp_parallel_mode):
super().__init__(dp_parallel_mode)
self._grads = dict()
self._params = dict()
self._num_elements_in_bucket = dict()
self.reset()
def num_elements_in_bucket(self, reduce_rank: int = None):
return self._num_elements_in_bucket[reduce_rank]
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
self._num_elements_in_bucket[reduce_rank] += num_elements
def add_grad(self, tensor, reduce_rank: int = None):
self._grads[reduce_rank].append(tensor)
def add_param(self, tensor, reduce_rank: int = None):
self._params[reduce_rank].append(tensor)
def reset(self):
keys = [None] + list(range(self._world_size))
self._grads = {rank: [] for rank in keys}
self._params = {rank: [] for rank in keys}
self._num_elements_in_bucket = {rank: 0 for rank in keys}
def reset_by_rank(self, reduce_rank=None):
self._grads[reduce_rank] = []
self._params[reduce_rank] = []
self._num_elements_in_bucket[reduce_rank] = 0
def get_grad(self, reduce_rank: int = None):
return self._grads[reduce_rank]
def get_param(self, reduce_rank: int = None):
return self._params[reduce_rank]
from typing import List
from torch import Tensor
from .base_store import BaseStore
class GradientStore(BaseStore):
def __init__(self, *args):
super().__init__(*args)
# bookkeeping data structures
self._averaged_gradients = dict()
# for backward reduction hooks
self._grad_acc_objs = []
def add_accumulate_grad_object(self, obj):
"""
Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
be attached successfully.
:param obj: An object of :class:`AccumulateGrad` class
:type obj: :class:`AccumulateGrad`
"""
self._grad_acc_objs.append(obj)
def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
"""
Return average gradients of a parameter group
:param group_id: The index of parameter group
:type group_id: int
:return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
:rtype: List[torch.Tensor]
"""
return self._averaged_gradients[group_id]
def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
"""
Append an average gradient to the list of averaged gradients of a parameter group
:param group_id: The index of a parameter group
:param tensor: A :class:`torch.Tensor` object
:type group_id: int
:type tensor: torch.Tensor
"""
if group_id in self._averaged_gradients:
self._averaged_gradients[group_id].append(tensor)
else:
self._averaged_gradients[group_id] = [tensor]
def reset_average_gradients_by_group(self, group_id: int) -> None:
"""
Reset the bookkeeping data structure for averaged gradients to an empty list
:param group_id: The index of a parameter group
:type group_id: int
"""
self._averaged_gradients[group_id] = []
from .base_store import BaseStore
from torch import Tensor
from typing import List
class ParameterStore(BaseStore):
def __init__(self, dp_paralle_mode):
super().__init__(dp_paralle_mode)
# param partitioning data structures
self._fp16_param_to_rank = dict()
self._rank_groupid_to_fp16_param_list = dict()
self._rank_group_id_to_flat_fp16_param = dict()
# param reduction data structures
self._is_param_reduced = dict()
self._reduced_param = []
def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
"""
Set the mapping between parameter to rank, each parameter should be owned by a rank.
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
:param rank: The rank of which the process is responsible for updating the parameter
:type rank: int
"""
self._fp16_param_to_rank[tensor] = rank
def get_param_rank(self, tensor: Tensor) -> int:
"""
Gives the rank which the parameter belongs to
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
"""
return self._fp16_param_to_rank[tensor]
def belongs_to_current_rank(self, tensor) -> bool:
"""
Check whether a parameter is supposed to be updated by the process of the current rank
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
:return: True if the parameter should be updated by the current rank. Otherwise false.
:rtype: bool
"""
tensor_rank = self._fp16_param_to_rank[tensor]
return tensor_rank == self._local_rank
def add_fp16_param_list_by_rank_group(self, rank, group_id,
tensor_list) -> None:
if rank not in self._rank_groupid_to_fp16_param_list:
self._rank_groupid_to_fp16_param_list[rank] = dict()
if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
self._rank_groupid_to_fp16_param_list[rank][group_id] = []
self._rank_groupid_to_fp16_param_list[rank][group_id].extend(
tensor_list)
def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
return self._rank_groupid_to_fp16_param_list[rank][group_id]
def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
if rank not in self._rank_group_id_to_flat_fp16_param:
self._rank_group_id_to_flat_fp16_param[rank] = dict()
self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor
def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
return self._rank_group_id_to_flat_fp16_param[rank][group_id]
def is_param_reduced(self, tensor):
return self._is_param_reduced[tensor]
def set_param_reduction_state(self, tensor, state):
self._is_param_reduced[tensor] = state
def get_param_reduction_states(self):
return self._is_param_reduced
def reset_previous_reduced_params(self):
self._reduced_param = []
def add_previous_reduced_param(self, tensor):
self._reduced_param.append(tensor)
def clear_grads_of_previous_reduced_params(self):
if len(self._reduced_param) > 0:
for param in self._reduced_param:
param.grad = None
self.reset_previous_reduced_params()
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
class TensorBucket:
def __init__(self, size):
self._max_size = size
self._current_size = 0
self._bucket = []
@property
def max_size(self):
return self._max_size
@property
def current_size(self):
return self._current_size
def is_full_or_oversized(self):
return self._current_size >= self._max_size
def is_empty(self):
return len(self._bucket) == 0
def add_to_bucket(self, tensor, allow_oversize=False):
tensor_size = tensor.numel()
if not allow_oversize and self.will_exceed_max_size(tensor_size):
msg = f"The param bucket max size {self._max_size} is exceeded" \
+ f"by tensor (size {tensor_size})"
raise RuntimeError(msg)
self._bucket.append(tensor)
self._current_size += tensor_size
def will_exceed_max_size(self, tensor_size):
expected_size = self._current_size + tensor_size
return expected_size > self._max_size
def get_bucket(self):
return self._bucket
def empty(self):
self._bucket = []
self._size = 0
def flatten(self):
return _flatten_dense_tensors(self._bucket)
def unflatten_and_copy(self, flat_tensor):
unflattened_tensor_list = _unflatten_dense_tensors(
flat_tensor, self._bucket)
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
from itertools import groupby
from colossalai.utils.cuda import get_current_device
import torch
import torch.distributed as dist
from colossalai.logging import get_dist_logger
from torch.optim import Optimizer
from .bookkeeping import ParameterStore, GradientStore, BucketStore, TensorBucket
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
from colossalai.nn.optimizer import ColossalaiOptimizer
from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor,
release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan)
from functools import partial
class ShardedOptimizer(ColossalaiOptimizer):
def __init__(
self,
optimizer: Optimizer,
# grad scaler config
initial_scale=2**32,
min_scale=1,
growth_factor=2,
backoff_factor=0.5,
growth_interval=1000,
hysteresis=2,
max_scale: int = 2**32,
# grad clipping
clip_grad_norm=2.0,
verbose=False,
# communication
reduce_bucket_size=500000000,
communication_dtype=torch.float16,
overlap_communication=False,
# stage 2
partition_grad=False,
dp_parallel_mode=ParallelMode.DATA,
mp_parallel_mode=ParallelMode.MODEL,
# cpu offload
cpu_offload=False):
# TODO: add support for
# 1. fp16 master weights
# 2. contiguous gradients
# 3. cpu offload
# 4. support when some parameters requires_grad = False
self._optimizer = optimizer
self._dtype = self._optimizer.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
self._verbose = verbose
# stage 2
self._partition_grads = partition_grad
# cpu_offload
self._cpu_offload = cpu_offload
# get process groups
self._dp_parallel_mode = dp_parallel_mode
self._mp_parallel_mode = mp_parallel_mode
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
self._world_size = gpc.get_world_size(dp_parallel_mode)
self._dp_group = gpc.get_group(dp_parallel_mode)
if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
self._mp_group = gpc.get_group(mp_parallel_mode)
else:
self._mp_group = None
# fp16 and fp32 params for mixed precision training
self._fp16_param_groups = dict()
self._fp32_flat_param_groups_of_current_rank = dict()
# communication params
self._overlap_communication = overlap_communication
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
# gradient scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
verbose=verbose)
self._found_overflow = torch.FloatTensor([0]).to(get_current_device())
# gradient clipping
self._clip_grad_norm = clip_grad_norm
# check argument conflict
self._sanity_checks()
# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
self._param_store = ParameterStore(self._dp_parallel_mode)
self._grad_store = GradientStore(self._dp_parallel_mode)
self._bucket_store = BucketStore(self._dp_parallel_mode)
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
for group_id, param_group in enumerate(self._optimizer.param_groups):
params = param_group['params']
# add the fp16 params to fp16_param_groups for bookkeeping
self._fp16_param_groups[group_id] = params
# assign parameters to ranks
# the params in the list are sorted
params_per_rank = self._partition_param_list(params)
# store the mapping between param to rank
# each param should belong to only one rank
for rank, params in enumerate(params_per_rank):
self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
for param in params:
self._param_store.set_param_to_rank(param, rank)
# move to cpu to make room to create the flat tensor
move_tensor(params, device='cpu')
# flatten the reordered tensors
for rank in range(self._world_size):
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
flat_tensor = flatten(tensor_list)
flat_tensor = flat_tensor.cuda()
self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
# sync parameters
for rank in range(self._world_size):
flat_tensor = self._param_store.get_flat_fp16_param_by_rank_group(rank, group_id)
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
# create a copy of fp32 weights of the parameters for which this rank is responsible
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
fp32_flat_current_rank = fp16_flat_current_rank.clone().float().detach()
device = 'cpu' if self._cpu_offload else get_current_device()
fp32_flat_current_rank = fp32_flat_current_rank.to(device)
fp32_flat_current_rank.requires_grad = True
self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
# need to replace the params in the `params` field in the optimizer
# so that when the optimizer calls step(), it only updates the tensors
# managed by this data parallel rank
param_group['params'] = [fp32_flat_current_rank]
# set reduction state
for param in self._fp16_param_groups[group_id]:
self._param_store.set_param_reduction_state(param, False)
# intialize communication stream for
# communication-compuation overlapping
if self._overlap_communication:
self._comm_stream = torch.cuda.Stream()
# reduction hook is only used if overlapping communication
# or stage 2 is used
# if it is stage 1 without overlapping, no hook will be attached
if self._overlap_communication or self._partition_grads:
self._attach_reduction_hook()
self._initialize_optimizer_states()
@property
def loss_scale(self):
return self.grad_scaler.scale
@property
def num_param_groups(self):
return len(self._fp16_param_groups)
def _partition_param_list(self, param_list):
params_per_rank = [[] for _ in range(self._world_size)]
numel_per_rank = [0 for _ in range(self._world_size)]
# partititon the parameters in a greedy fashion
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
for param in sorted_params:
# allocate this parameter to the rank with
# the smallest numel for load balancing purpose
rank_to_go = numel_per_rank.index(min(numel_per_rank))
params_per_rank[rank_to_go].append(param)
numel_per_rank[rank_to_go] += param.numel()
if self._verbose:
self._logger.info(f'Number of elements on ranks: {numel_per_rank}',
ranks=[0],
parallel_mode=self._dp_parallel_mode)
return params_per_rank
def _initialize_optimizer_states(self):
# create a dummy zero tensor which has the same shape as that of the param
# set this dummpy zero tensor as grad
for group_id in range(len(self._fp32_flat_param_groups_of_current_rank)):
fp32_partition_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp32_partition_grad = torch.zeros_like(fp32_partition_param)
fp32_partition_param.grad = fp32_partition_grad
# update the parameter with zero gradients for initialization of optimizer states
self._optimizer.step()
# remove the grad of the paramter to save memory
for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items():
fp32_flat_tensor.grad = None
def _sanity_checks(self):
assert torch.cuda.is_available(), 'CUDA is required'
assert self._dtype == torch.float16, \
f'Parameters are expected to be of type torch.float16, but got {self._dtype}'
###########################################################
# Backward Reduction Hook
###########################################################
def _attach_reduction_hook(self):
# we iterate over the fp16 params
# on each param, we register a hook to its AccumulateGrad object
for group_id in range(self.num_param_groups):
param_group = self._fp16_param_groups[group_id]
for param in param_group:
if param.requires_grad:
# determines the reduction destionation rank
# this is only valid for stage 2
# dst_rank = None means using all-reduce
# else using reduce
if self._partition_grads:
reduce_rank = self._param_store.get_param_rank(param)
else:
reduce_rank = None
def _define_and_attach(param, reduce_rank):
# get the AccumulateGrad object of the param itself
accum_grad_obj = get_grad_accumulate_object(param)
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
param=param,
reduce_rank=reduce_rank)
# define hook
# NOT IMPORTANT BUT GOOD TO KNOW:
# args here is not grad, but allow_unreacable and accumulate_grad
def reduce_grad_hook(*args):
reduction_func()
accum_grad_obj.register_hook(reduce_grad_hook)
_define_and_attach(param, reduce_rank)
def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
param_size = param.numel()
# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._reduce_grads_in_bucket(reduce_rank)
# the param must not be reduced to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+ 'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)
# the param must have grad for reduction
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
self._bucket_store.add_grad(param.grad, reduce_rank)
self._bucket_store.add_param(param, reduce_rank)
def _reduce_grads_in_bucket(self, reduce_rank=None):
# reduce grads
self._reduce_grads_by_rank(reduce_rank=reduce_rank,
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
# use communication stream if overlapping
# communication with computation
if self._overlap_communication:
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
for param in params_in_bucket:
# the is_param_reduced flag should be False showing that
# this param is not reduced before calling self._reduce_grads_by_rank
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has been reduced, ' + \
'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)
# update the flag
self._param_store.set_param_reduction_state(param, True)
# if partition grads = True
# we do not keep the gradient after reduction
if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
if self._overlap_communication:
# we need to keep this gradient for now as reduction may
# be completed yet since it is using a different cuda stream
self._param_store.add_previous_reduced_param(param)
else:
param.grad = None
self._bucket_store.reset_by_rank(reduce_rank)
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
grad_buckets_by_dtype = split_half_float_double(grads)
for tensor_list in grad_buckets_by_dtype:
self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)
##############################
# Reduction Utility Function #
##############################
def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
param_bucket = TensorBucket(size=bucket_size)
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)
if param_bucket.is_full_or_oversized():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
param_bucket.empty()
if not param_bucket.is_empty():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
flat = bucket.flatten()
reduced_flat = reduce_tensor(tensor=flat,
dtype=self._communication_dtype,
dst_rank=reduce_rank,
parallel_mode=self._dp_parallel_mode)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._local_rank:
bucket.unflatten_and_copy(reduced_flat)
################################
# torch.optim.Optimizer methods
################################
def backward(self, loss, retain_graph=True):
loss = self.loss_scale * loss
loss.backward(retain_graph=retain_graph)
def zero_grad(self, set_to_none=True):
"""
Set parameter gradients to zero. If set_to_none = True, gradient
will be set to None to save memory.
:param set_to_none: Whether set the gradient to None. Default value is True.
:type set_to_none: bool
"""
for group_id, param_group in self._fp16_param_groups.items():
for param in param_group:
if set_to_none:
param.grad = None
else:
if param.grad is not None:
param.grad.detach()
param.grad.zero_()
####################
# Update Parameter #
####################
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
# check for overflow
found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)
# update loss scale if overflow occurs
if found_inf:
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return
# copy the grad of fp16 param to fp32 param
single_grad_partition_groups = []
norm_groups = []
for group_id in range(self.num_param_groups):
# compute norm
norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
rank=self._local_rank),
dp_group=self._dp_group,
mp_group=self._mp_group)
norm_groups.append(norm_group)
# create flat gradient for the flat fp32 params
fp16_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
flat_fp16_avg_grads = flatten(fp16_avg_grads)
dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
assert param_shape == flat_fp32_avg_grads.shape, \
f'fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}'
single_grad_partition_groups.append(flat_fp32_avg_grads)
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
self._grad_store._averaged_gradients[group_id] = []
self._grad_store._averaged_gradients[group_id] = []
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
# update the parameters
self._optimizer.step()
# release the fp32 grad
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
# update fp16 partition updated by the current rank
for group_id in range(len(self._fp16_param_groups)):
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id)
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id].to(fp16_param.device)
fp16_param.data.copy_(fp32_param)
# broadcast the updated model weights
handles = []
for group_id in range(self.num_param_groups):
for rank in range(self._world_size):
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True)
handles.append(handle)
for handle in handles:
handle.wait()
##################
# FP16 Utilities #
##################
def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(0.0)
# check for overflow
for group_id in range(len(self._fp16_param_groups)):
for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
if avg_grad is not None and has_inf_or_nan(avg_grad):
self._found_overflow.fill_(1.0)
break
# all-reduce across dp group
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group)
# all-reduce over model parallel group
if self._mp_group:
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group)
if self._found_overflow.item() > 0:
return True
else:
return False
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self._clip_grad_norm > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
if clip > 1:
combined_scale = clip * self.loss_scale
for grad in grad_groups_flat:
grad.data.mul_(1. / combined_scale)
############################
# Gradient Synchronization #
############################
def sync_grad(self):
if not self._partition_grads:
self._reduce_grad_stage1()
else:
# TODO: support async comm in reduce
self._reduce_grad_stage2()
# update param already reduced flag
reduction_states = self._param_store.get_param_reduction_states()
for tensor, state in reduction_states.items():
reduction_states[tensor] = False
# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
# accumulate gradient
avg_gradients = self._grad_store._averaged_gradients
for group_id in range(self.num_param_groups):
param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)
if group_id not in avg_gradients:
avg_gradients[group_id] = []
param_idx = 0
for param in param_group:
if param.grad is not None:
if len(avg_gradients[group_id]) == param_idx:
avg_gradients[group_id].append(param.grad)
else:
avg_gradients[group_id][param_idx].add_(param.grad)
param_idx += 1
# the gradients needed are stored in the avg_gradients buffer
# thus, can clear this
self.zero_grad()
def _reduce_grad_stage1(self):
# if not overlapping communication (no reduction hook is attached)
# we need to manually reduce these gradients
if not self._overlap_communication:
for group_id in range(len(self._fp16_param_groups)):
param_group = self._fp16_param_groups[group_id]
for param in param_group:
if param.grad is not None:
self._reduce_and_remove_grads_by_bucket(param)
# we need to reduce the gradients
# left in the communication bucket
self._reduce_grads_in_bucket()
def _reduce_grad_stage2(self):
# when partition_grads is True, reduction hooks
# are attached in the __init__ function, so we
# only need to reduce the gradients
# left in the communication bucket
for reduce_rank in range(self._world_size):
self._reduce_grads_in_bucket(reduce_rank)
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''
import math
import torch
import torch.distributed as dist
try:
from deepspeed.git_version_info import version
from deepspeed.moe.utils import is_moe_param
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.op_builder import UtilsBuilder
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
except ImportError:
pass
from packaging import version as pkg_version
from torch._six import inf
from torch.distributed.distributed_c10d import _get_global_rank
from torch.optim import Optimizer
from colossalai.core import global_context as gpc
from colossalai.utils import report_memory_usage
from colossalai.utils.common import is_model_parallel_parameter
from .loss_scaler import LossScaler, DynamicLossScaler
from colossalai.context import ParallelMode
# Toggle this to true to enable correctness test
# with gradient partitioning and without
pg_correctness_test = False
def input(msg):
return
def split_half_float_double(tensors):
dtypes = [
"torch.cuda.HalfTensor",
"torch.cuda.FloatTensor",
"torch.cuda.DoubleTensor"
]
buckets = []
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append(bucket)
return buckets
def isclose(a, b, rtol=1e-09, atol=0.0):
return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)
def lcm(x, y):
from fractions import gcd # or can import gcd from `math` in Python 3
return x * y // gcd(x, y)
def get_alignment_padding(tensor_list, alignment):
num_elements = sum([tensor.numel() for tensor in tensor_list])
remainder = num_elements % alignment
return (alignment - remainder) if remainder else remainder
def move_to_cpu(tensor_list):
for tensor in tensor_list:
tensor.data = tensor.data.cpu()
def print_rank_msg(msg):
print(f"rank {dist.get_rank()} - {msg}")
class ZeroRedundancyOptimizer_Level_2(Optimizer):
"""
ZeroRedundancyOptimizer_Level_2 designed to reduce the memory footprint
required for training large deep learning models.
For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
https://arxiv.org/abs/1910.02054
"""
def __init__(self,
init_optimizer,
dp_parallel_mode=ParallelMode.DATA,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=False,
contiguous_gradients=True,
reduce_bucket_size=500000000,
allgather_bucket_size=5000000000,
reduce_scatter=True,
overlap_comm=False,
cpu_offload=False,
clip_grad=0.0,
allreduce_always_fp32=False,
postscale_gradients=True,
gradient_predivide_factor=1.0,
gradient_accumulation_steps=1,
ignore_unused_parameters=True,
round_robin_gradients=False,
fp16_master_weights_and_gradients=False):
# mpu = None is removed from the parameter list
# tensor parallel will be automatically detected later
# LSG: default arguments for compatibility
has_moe_layers = False
partition_grads = True
expert_parallel_group = None
expert_data_parallel_group = None
self.timers = None
self.defaults = init_optimizer.defaults
dp_process_group = gpc.get_group(dp_parallel_mode)
if gpc.get_world_size(dp_parallel_mode) == 1:
partition_grads = False # for compatibility with dp size = 1
self.verbose = verbose
if dist.get_rank() == 0 and self.verbose:
print(f"Reduce bucket size {reduce_bucket_size}")
print(f"Allgather bucket size {allgather_bucket_size}")
print(f"CPU Offload: {cpu_offload}")
print(
f'Round robin gradient partitioning: {round_robin_gradients}')
# The fused optimizer does all the work. We need this layer for two reason:
# 1. maintain same user API from apex.fp16_utils
# 2. keep common stuff here in case we need to add ne552w fused optimizer later
# differences from apex.fp16_utils:
# - assume all model params in fp16
# - assume all params requires grad
# - flat by groups, not keeping state. TODO: remove state explicitly?
# - master gard and unflat master weight never exist. TODO: a way to save out unflat master?
if not torch.cuda.is_available:
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
# Load pre-built or JIT compile (un)flatten ops
util_ops = UtilsBuilder().load()
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten
# ZeRO stage 1 (False) or 2 (True)
self.partition_gradients = partition_grads
self.reduce_scatter = reduce_scatter
self.overlap_comm = overlap_comm
self.cpu_offload = cpu_offload
self.deepspeed_adam_offload = cpu_offload
self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'
self.dp_process_group = dp_process_group
# expert parallel group
self.ep_process_group = expert_parallel_group
# data parallel group for experts
self.expert_dp_process_group = expert_data_parallel_group
# data parallel size for non-experts
dp_size = dist.get_world_size(group=self.dp_process_group)
# For MoE models this maybe different for different param group
# It will be modified during MoE setup later in the init
self.real_dp_process_group = [
dp_process_group for i in range(len(self.optimizer.param_groups))
]
self.partition_count = [dp_size for i in range(
len(self.optimizer.param_groups))]
self.is_gradient_accumulation_boundary = True
# CPU-Offload requires contiguous gradients
self.contiguous_gradients = contiguous_gradients or cpu_offload
self.has_moe_layers = has_moe_layers
if self.has_moe_layers:
self._configure_moe_settings()
if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_world_size(ParallelMode.TENSOR) == 1:
self.model_parallel_group = None
self.model_parallel_rank = 0
else:
self.model_parallel_group = gpc.get_group(ParallelMode.TENSOR)
self.model_parallel_rank = gpc.get_local_rank(ParallelMode.TENSOR)
self.overflow = False
self.clip_grad = clip_grad
self.allreduce_always_fp32 = allreduce_always_fp32
self.gradient_predivide_factor = gradient_predivide_factor
self.postscale_gradients = postscale_gradients
self.gradient_accumulation_steps = gradient_accumulation_steps
self.micro_step_id = 0
self.ignore_unused_parameters = ignore_unused_parameters
self.round_robin_gradients = round_robin_gradients
self.extra_large_param_to_reduce = None
self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients
if self.fp16_master_weights_and_gradients:
assert self.cpu_offload and type(self.optimizer) in [
DeepSpeedCPUAdam], f"fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32. Currenty only supported using ZeRO-Offload with DeepSpeedCPUAdam. But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}. Either disable fp16_master_weights_and_gradients or enable ZeRO-2 Offload with DeepSpeedCPUAdam"
if self.reduce_scatter:
assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"
assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled"
assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled"
# param flattened by groups
self.fp16_groups = []
self.fp16_groups_flat = []
# param partitioned by data parallel degree
# this will contain a list of equal sized tensors
# each of which will be updated by a different process
self.parallel_partitioned_fp16_groups = []
# a single 32-bit partition of the parallel partitioned parameters
# that this process will update
self.single_partition_of_fp32_groups = []
# param partition info
# These are the parameters in each group that will not be updated by this process directly
self.params_not_in_partition = []
# These are the parameters that will be updated by this process directly
self.params_in_partition = []
# Offset from the first paramter in the the self.params_in_partition
# the parameter boundaries may not align with partition boundaries
# so we need to keep track of the offset
self.first_offset = []
# number of elements per partition in each group
self.partition_size = []
# align nccl all-gather send buffers to 4-bye boundary
# 4-byte alignment/sizeof(fp16) = 2
self.nccl_start_alignment_factor = 2
assert (
allgather_bucket_size % self.nccl_start_alignment_factor == 0), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "
self.all_reduce_print = False
self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
self.round_robin_fp16_groups = []
self.round_robin_fp6_indices = []
# padding on each partition for alignment purposes
self.groups_padding = []
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
# push this group to list before modify
# TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group
self.fp16_groups.append(param_group['params'])
# Record padding required to align group to world size
if partition_id == dist.get_world_size(
group=self.real_dp_process_group[i]) - 1:
padding = get_alignment_padding(self.fp16_groups[i],
self.partition_count[i])
else:
padding = 0
self.groups_padding.append(padding)
# not sure why apex was cloning the weights before flattening
# removing cloning here
if self.verbose:
report_memory_usage(f"Before moving param group {i} to CPU")
# move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.fp16_groups[i])
if self.verbose:
report_memory_usage(f"After moving param group {i} to CPU")
# Reorder group parameters for load balancing of gradient partitioning during backward among ranks.
# This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
# For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging
# to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).
if self.round_robin_gradients:
round_robin_tensors, round_robin_indices = self._round_robin_reorder(
self.fp16_groups[i],
dist.get_world_size(group=self.real_dp_process_group[i])
)
else:
round_robin_tensors = self.fp16_groups[i]
round_robin_indices = list(range(len(self.fp16_groups[i])))
self.round_robin_fp16_groups.append(round_robin_tensors)
self.round_robin_fp6_indices.append(round_robin_indices)
# create flat buffer in CPU and move to GPU
self.fp16_groups_flat.append(
self.flatten_dense_tensors_aligned(
self.round_robin_fp16_groups[i],
self.nccl_start_alignment_factor *
dist.get_world_size(group=self.real_dp_process_group[i])).cuda(
torch.cuda.current_device()))
if self.verbose:
report_memory_usage(
f"After flattening and moving param group {i} to GPU")
if dist.get_rank(group=self.real_dp_process_group[i]) == 0:
report_memory_usage(
f"After Flattening and after emptying param group {i} cache")
# set model fp16 weight to slices of flattened buffer
self._update_model_fp16_weights(i)
# divide the flat weights into near equal partition equal to the data parallel degree
# each process will compute on a different part of the partition
data_parallel_partitions = self.get_data_parallel_partitions(
self.fp16_groups_flat[i],
i)
self.parallel_partitioned_fp16_groups.append(
data_parallel_partitions)
# verify that data partition start locations are 4-byte aligned
for partitioned_data in data_parallel_partitions:
assert (partitioned_data.data_ptr() %
(2 * self.nccl_start_alignment_factor) == 0)
# a partition of the fp32 master weights that will be updated by this process
if not fp16_master_weights_and_gradients:
self.single_partition_of_fp32_groups.append(
self.parallel_partitioned_fp16_groups[i][partition_id].to(
self.device).clone().float().detach())
else:
self.single_partition_of_fp32_groups.append(
self.parallel_partitioned_fp16_groups[i][partition_id].to(
self.device).clone().half().detach())
# modify optimizer of have flat master weight
self.single_partition_of_fp32_groups[
i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = [self.single_partition_of_fp32_groups[i]]
partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(
group=self.real_dp_process_group[i])
params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(
self.round_robin_fp16_groups[i],
partition_size,
partition_id)
self.partition_size.append(partition_size)
self.params_in_partition.append(params_in_partition)
self.params_not_in_partition.append(params_not_in_partition)
self.first_offset.append(first_offset)
for rank in range(dist.get_world_size()):
if dist.get_rank() == rank and self.verbose:
print(
f"Rank: {rank} partition count {self.partition_count} and sizes{[(p.numel(), self.is_moe_param_group[i] if hasattr(self, 'is_moe_param_group') else False) for i, p in enumerate(self.single_partition_of_fp32_groups)]} "
)
dist.barrier()
# exit(0)
self.reduce_bucket_size = int(reduce_bucket_size)
self.allgather_bucket_size = int(allgather_bucket_size)
self.reduction_event = torch.cuda.Event(
enable_timing=False, blocking=False)
self.reduction_stream = torch.cuda.Stream()
self.cpu_computation_stream = torch.cuda.Stream()
self.copy_grad_stream = torch.cuda.Stream()
self.callback_queued = False
self.param_dict = {}
# map between param_id and bool to specify if a param is in this partition
self.is_param_in_current_partition = {}
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.elements_in_ipg_bucket = 0
self.params_already_reduced = []
self._release_ipg_buffers()
self.previous_reduced_grads = None
self.ipg_bucket_has_moe_params = False
# simplified param id
self.param_id = {}
largest_param_numel = 0
count = 0
for i, params_group in enumerate(self.fp16_groups):
for param in params_group:
unique_id = id(param)
self.param_id[unique_id] = count
self.param_dict[count] = param
self.params_already_reduced.append(False)
if param.numel() > largest_param_numel:
largest_param_numel = param.numel()
count = count + 1
for param_group in self.params_in_partition:
for param in param_group:
self.is_param_in_current_partition[self.get_param_id(
param)] = True
for param_group in self.params_not_in_partition:
for param in param_group:
self.is_param_in_current_partition[self.get_param_id(
param)] = False
if self.cpu_offload:
self.accumulated_grads_in_cpu = {}
self.norm_for_param_grads = {}
self.local_overflow = False
self.grad_position = {}
self.temp_grad_buffer_for_cpu_offload = torch.zeros(
largest_param_numel,
device=self.device,
dtype=self.dtype).pin_memory()
self.temp_grad_buffer_for_gpu_offload = torch.zeros(
largest_param_numel,
device=torch.cuda.current_device(),
dtype=self.dtype)
for i, params_group in enumerate(self.fp16_groups):
self.get_grad_position(i,
self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i])
# mapping from parameter to partition that it belongs to
self.param_to_partition_ids = {}
# stores if a partition has been reduced in this step
self.is_partition_reduced = {}
# number of grads in partition that still need to be computed
self.remaining_grads_in_partition = {}
# total number of grads in partition
self.total_grads_in_partition = {}
# stores if a grad in a partition has been computed or not
self.is_grad_computed = {}
# stores the offset at which a parameter gradient needs to be inserted in a partition
self.grad_partition_insertion_offset = {}
# the offset in the gradient at which it must be inserted at the beginning of the partition
self.grad_start_offset = {}
# will store the averaged gradients required by this partition
self.averaged_gradients = {}
# store index of first parameter in each partition
self.first_param_index_in_partition = {}
# initializes all data structures for implementing gradient partitioning
self.initialize_gradient_partitioning_data_structures()
# resets the data structure value for the next backward propagation
self.reset_partition_gradient_structures()
# creates backward hooks for gradient partitioning
if self.partition_gradients or self.overlap_comm:
self.create_reduce_and_remove_grad_hooks()
# we may have a way of fusing dynamic scale. Do not support for now
if self.dtype == torch.float or not dynamic_loss_scale:
loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(scale=loss_scale_value)
cur_iter = 0
else:
if dynamic_loss_args is None:
self.loss_scaler = DynamicLossScaler()
else:
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
self.dynamic_loss_scale = True
if self.verbose:
report_memory_usage("Before initializing optimizer states")
self.initialize_optimizer_states()
if self.verbose:
report_memory_usage("After initializing optimizer states")
if dist.get_rank() == 0:
print(f"optimizer state initialized")
if dist.get_rank(group=self.dp_process_group) == 0:
report_memory_usage(f"After initializing ZeRO optimizer")
def _configure_moe_settings(self):
assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
def is_moe_group(group):
return 'moe' in group and group['moe']
assert any([is_moe_group(group) for group in
self.optimizer.param_groups]), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"
self.is_moe_param_group = []
for i, group in enumerate(self.optimizer.param_groups):
if is_moe_group(group):
assert all(
[is_moe_param(param) for param in group['params']]), "All params in MoE group must be MoE params"
self.real_dp_process_group[i] = self.expert_dp_process_group
self.partition_count[i] = dist.get_world_size(
group=self.expert_dp_process_group)
self.is_moe_param_group.append(True)
else:
self.is_moe_param_group.append(False)
assert self.expert_dp_process_group is not None, "Expert data parallel group should be configured with MoE"
assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE"
def _update_model_fp16_weights(self, group_index):
updated_params = self.unflatten(self.fp16_groups_flat[group_index],
self.round_robin_fp16_groups[group_index])
for p, q in zip(self.round_robin_fp16_groups[group_index], updated_params):
p.data = q.data
# set model fp16 weight to slices of reordered flattened buffer
for param_index, param in enumerate(self.fp16_groups[group_index]):
new_index = self.round_robin_fp6_indices[group_index][param_index]
param.data = self.round_robin_fp16_groups[group_index][new_index].data
def _round_robin_reorder(self, tensor_list, num_partitions):
# disable round robin if need to debug something
# return tensor_list, list(range(len(tensor_list)))
partition_tensors = {}
for i, tensor in enumerate(tensor_list):
j = i % num_partitions
if not j in partition_tensors:
partition_tensors[j] = []
partition_tensors[j].append((i, tensor))
reordered_tensors = []
reordered_indices = {}
for partition_index in partition_tensors.keys():
for i, (original_index, tensor) in enumerate(partition_tensors[partition_index]):
reordered_indices[original_index] = len(reordered_tensors)
reordered_tensors.append(tensor)
return reordered_tensors, reordered_indices
def _release_ipg_buffers(self):
if self.contiguous_gradients:
self.ipg_buffer = None
self.grads_in_partition = None
self.grads_in_partition_offset = 0
def initialize_optimizer_states(self):
for i, group in enumerate(self.fp16_groups):
single_grad_partition = torch.zeros(
int(self.partition_size[i]),
dtype=self.single_partition_of_fp32_groups[i].dtype,
device=self.device)
self.single_partition_of_fp32_groups[
i].grad = single_grad_partition.pin_memory(
) if self.cpu_offload else single_grad_partition
self.optimizer.step()
if not self.cpu_offload:
for group in self.single_partition_of_fp32_groups:
group.grad = None # class init
return
#########################################################################
#################### ZeRO Stage 1 - reduce gradients ####################
#########################################################################
def reduce_gradients(self, pipeline_parallel=False):
world_size = dist.get_world_size(self.dp_process_group)
my_rank = dist.get_rank(self.dp_process_group)
# with PP we must create ipg buffer, since backward is handled outside zero
if pipeline_parallel and self.contiguous_gradients:
self.ipg_buffer = []
buf_0 = torch.empty(int(self.reduce_bucket_size),
dtype=self.dtype,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_0)
self.ipg_index = 0
if not self.overlap_comm:
for i, group in enumerate(self.fp16_groups):
for param in group:
if param.grad is not None:
self.reduce_ready_partitions_and_remove_grads(param, i)
# reduce any pending grads in either hook/non-hook case
self.overlapping_partition_gradients_reduce_epilogue()
#########################################################################
#########################ZeRO Partition Gradients########################
#########################################################################
def get_first_param_index(self, group_id, param_group, partition_id):
for index, param in enumerate(param_group):
param_id = self.get_param_id(param)
if partition_id in self.param_to_partition_ids[group_id][param_id]:
return index
return None
def initialize_gradient_partitioning_data_structures(self):
for i, param_group in enumerate(self.round_robin_fp16_groups):
total_partitions = dist.get_world_size(
group=self.real_dp_process_group[i])
self.param_to_partition_ids[i] = {}
self.is_partition_reduced[i] = {}
self.total_grads_in_partition[i] = {}
self.remaining_grads_in_partition[i] = {}
self.is_grad_computed[i] = {}
self.grad_partition_insertion_offset[i] = {}
self.grad_start_offset[i] = {}
self.first_param_index_in_partition[i] = {}
for partition_id in range(total_partitions):
self.is_grad_computed[i][partition_id] = {}
self.grad_partition_insertion_offset[i][partition_id] = {}
self.grad_start_offset[i][partition_id] = {}
self.total_grads_in_partition[i][partition_id] = 0
self.initialize_gradient_partition(
i, param_group, partition_id)
self.is_partition_reduced[i][partition_id] = False
self.first_param_index_in_partition[i][
partition_id] = self.get_first_param_index(
i,
param_group,
partition_id)
def independent_gradient_partition_epilogue(self):
if self.verbose:
self.report_ipg_memory_usage(
f"In ipg_epilogue before reduce_ipg_grads", 0)
self.reduce_ipg_grads()
if self.verbose:
self.report_ipg_memory_usage(
f"In ipg_epilogue after reduce_ipg_grads", 0)
# if dist.get_rank() == 0:
# print()("Params already reduced %s", self.params_already_reduced)
for i in range(len(self.params_already_reduced)):
self.params_already_reduced[i] = False
if self.overlap_comm:
torch.cuda.synchronize()
# It is safe to clear previously reduced grads of other partitions
self._clear_previous_reduced_grads()
if self.cpu_offload is False:
for i, _ in enumerate(self.fp16_groups):
if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
self.averaged_gradients[i] = self.get_flat_partition(
self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=self.dtype,
device=torch.cuda.current_device(),
return_tensor_list=True)
else:
avg_new = self.get_flat_partition(self.params_in_partition[i],
self.first_offset[i],
self.partition_size[i],
dtype=self.dtype,
device=torch.cuda.current_device(),
return_tensor_list=True)
for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new):
accumulated_grad.add_(new_avg_grad)
self._release_ipg_buffers()
# No need to keep the gradients anymore.
# All gradients required by the step
# are in self.averaged_gradients
self.zero_grad()
if self.verbose:
report_memory_usage(f"End ipg_epilogue")
# resets all partition to no reduced
# sets remaining grads to the total number of grads in each partition
# set is grad computed to false for all grads in partition
def reset_partition_gradient_structures(self):
for i, _ in enumerate(self.fp16_groups):
total_partitions = dist.get_world_size(
group=self.real_dp_process_group[i])
for partition_id in range(total_partitions):
self.is_partition_reduced[i][partition_id] = False
self.remaining_grads_in_partition[i][
partition_id] = self.total_grads_in_partition[i][partition_id]
for param_id in self.is_grad_computed[i][partition_id]:
self.is_grad_computed[i][partition_id][param_id] = False
def initialize_gradient_partition(self, i, param_group, partition_id):
def set_key_value_list(dictionary, key, value):
if key in dictionary:
dictionary[key].append(value)
else:
dictionary[key] = [value]
def increment_value(dictionary, key):
if key in dictionary:
dictionary[key] += 1
else:
dictionary[key] = 1
partition_size = self.partition_size[i]
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for param in param_group:
param_size = param.numel()
param_id = self.get_param_id(param)
if (current_index >= start_index and current_index < end_index):
set_key_value_list(self.param_to_partition_ids[i],
param_id,
partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][
param_id] = current_index - start_index
self.grad_start_offset[i][partition_id][param_id] = 0
elif start_index > current_index and start_index < (current_index +
param_size):
assert (
first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
set_key_value_list(self.param_to_partition_ids[i],
param_id,
partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][param_id] = 0
self.grad_start_offset[i][partition_id][param_id] = first_offset
current_index = current_index + param_size
def overlapping_partition_gradients_reduce_epilogue(self):
self.independent_gradient_partition_epilogue()
def create_reduce_and_remove_grad_hooks(self):
self.grad_accs = []
for i, param_group in enumerate(self.fp16_groups):
for param in param_group:
if param.requires_grad:
def wrapper(param, i):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def reduce_partition_and_remove_grads(*notneeded):
self.reduce_ready_partitions_and_remove_grads(
param, i)
grad_acc.register_hook(
reduce_partition_and_remove_grads)
self.grad_accs.append(grad_acc)
wrapper(param, i)
def get_param_id(self, param):
unique_id = id(param)
return self.param_id[unique_id]
def report_ipg_memory_usage(self, tag, param_elems):
elem_count = self.elements_in_ipg_bucket + param_elems
percent_of_bucket_size = (
100.0 * elem_count) // self.reduce_bucket_size
if self.verbose:
report_memory_usage(
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
)
# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(self, tensor_list, alignment):
num_elements = 0
for tensor in tensor_list:
num_elements = num_elements + tensor.numel()
remaining = num_elements % alignment
if remaining:
elements_to_add = alignment - remaining
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
num_elements = num_elements + elements_to_add
else:
padded_tensor_list = tensor_list
return self.flatten(padded_tensor_list)
############### Independent Partition Gradient ########################
def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
param.numel())
self.reduce_ipg_grads()
if self.contiguous_gradients and self.overlap_comm:
# Swap ipg_index between 0 and 1
self.ipg_index = 1 - self.ipg_index
self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads",
param.numel())
param_id = self.get_param_id(param)
assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Gradient computed twice for this partition. \
Multiple gradient reduction is currently not supported"
if param.numel() > self.reduce_bucket_size:
self.extra_large_param_to_reduce = param
elif self.contiguous_gradients:
# keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
0,
self.elements_in_ipg_bucket,
param.numel())
new_grad_tensor.copy_(param.grad.view(-1))
param.grad.data = new_grad_tensor.data.view_as(param.grad)
self.elements_in_ipg_bucket += param.numel()
assert param.grad is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"
self.grads_in_ipg_bucket.append(param.grad)
self.params_in_ipg_bucket.append((i, param, param_id))
# make sure the average tensor function knows how to average the gradients
if is_moe_param(param):
self.ipg_bucket_has_moe_params = True
self.report_ipg_memory_usage("End ipg_remove_grads", 0)
def print_rank_0(self, message):
if dist.get_rank() == 0 and self.verbose:
print(message)
def gradient_reduction_w_predivide(self, tensor):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
tensor_to_allreduce = tensor
if self.allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
if self.postscale_gradients:
if self.gradient_predivide_factor != 1.0:
tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if self.gradient_predivide_factor != dp_world_size:
tensor_to_allreduce.mul_(
self.gradient_predivide_factor / dp_world_size)
else:
tensor_to_allreduce.div_(dp_world_size)
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
tensor.copy_(tensor_to_allreduce)
return tensor
def average_tensor(self, tensor):
if self.overlap_comm:
torch.cuda.synchronize()
stream = self.reduction_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
if not self.reduce_scatter:
self.gradient_reduction_w_predivide(tensor)
return
# Accumulate destination ranks and bucket offsets for each gradient slice.
# Note: potential future optimization, record access pattern of parameters
# in backward pass and partition gradients w.r.t. access pattern so that our
# bucket is guaranteed to be contiguous w.r.t. ranks
rank_and_offsets = []
real_dp_process_group = []
curr_size = 0
prev_id = -1
process_group = self.dp_process_group
# count = 0
for i, param, param_id in self.params_in_ipg_bucket:
process_group = self.dp_process_group
# Averages gradients at parameter level if ipg has a moe param
# Otherwise averaging is done at the entire buffer level at the end of the loop
if self.ipg_bucket_has_moe_params:
process_group = self.expert_dp_process_group if is_moe_param(
param) else self.dp_process_group
param.grad.data.div_(
dist.get_world_size(group=process_group))
partition_ids = self.param_to_partition_ids[i][param_id]
partition_size = self.partition_size[i]
# Get all partition ids + their offsets
partition_ids_w_offsets = []
for partition_id in partition_ids:
offset = self.grad_start_offset[i][partition_id][param_id]
partition_ids_w_offsets.append((partition_id, offset))
partition_ids_w_offsets.sort(key=lambda t: t[1])
# Calculate rank and offsets for grad slices
for idx in range(len(partition_ids_w_offsets)):
partition_id, offset = partition_ids_w_offsets[idx]
# if dist.get_rank() == 0 and count < 100:
# print(f"Rank {dist.get_rank()} rank offet id {idx} calculated dp size {dist.get_world_size(group=process_group)} real dp size {dist.get_world_size(self.real_dp_process_group[i])} and dst: {partition_id}")
# count += 1
# Calculate numel for grad slice depending on partition location
if idx == len(partition_ids_w_offsets) - 1:
# Last partition_id uses its own offset
numel = param.numel() - offset
else:
# Set numel to next partition's offset
numel = partition_ids_w_offsets[idx + 1][1] - offset
# Merge bucket ranges if they belong to the same rank
if partition_id == prev_id:
prev_pid, prev_size, prev_numel = rank_and_offsets[-1]
rank_and_offsets[-1] = (prev_pid,
prev_size, prev_numel + numel)
else:
rank_and_offsets.append(
(partition_id, curr_size, numel))
real_dp_process_group.append(process_group)
curr_size += numel
prev_id = partition_id
if not self.ipg_bucket_has_moe_params:
tensor.div_(dist.get_world_size(group=self.dp_process_group))
async_handles = []
for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets):
grad_slice = tensor.narrow(0, int(bucket_offset), int(numel))
# if dist.get_rank() == 0:
# print(f"Rank {dist.get_rank()} rank offet id {i} real dp size {dist.get_world_size(group=real_dp_process_group[i])} and dst: {dst}")
# dist.barrier()
# dist.barrier()
dst_rank = _get_global_rank(real_dp_process_group[i], dst)
async_handle = dist.reduce(grad_slice,
dst=dst_rank,
group=real_dp_process_group[i],
async_op=True)
async_handles.append(async_handle)
for handle in async_handles:
handle.wait()
##############################################################################
############################# CPU Offload Methods#############################
##############################################################################
def get_grad_position(self, group_id, tensor_list, first_offset, partition_size):
current_offset = 0
for i, tensor in enumerate(tensor_list):
param_id = self.get_param_id(tensor)
param_start_offset = 0
num_elements = tensor.numel()
tensor_offset = 0
# we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
param_start_offset = first_offset
# we dont need all elements of the tensor
if num_elements > (partition_size - current_offset):
num_elements = partition_size - current_offset
self.grad_position[param_id] = [
int(group_id),
int(param_start_offset),
int(current_offset),
int(num_elements)
]
current_offset += num_elements
def update_overflow_tracker_for_param_grad(self, param):
if param.grad is not None and self._has_inf_or_nan(param.grad.data):
self.local_overflow = True
def async_accumulate_grad_in_cpu_via_gpu(self, param):
param_id = self.get_param_id(param)
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
# copy to a preexisiting buffer to avoid memory allocation penalty
dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(
0,
0,
param.numel())
# buffer for storing gradients for this parameter in CPU
def buffer_to_accumulate_to_in_cpu():
if not self.fp16_master_weights_and_gradients:
return torch.zeros(param.numel(),
dtype=param.dtype,
device=self.device).pin_memory()
else:
return self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(
0,
dest_offset,
num_elements)
# accumulate gradients into param.grad or parts of it that belongs to this parittion
def accumulate_gradients():
if not self.fp16_master_weights_and_gradients:
dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1),
non_blocking=True)
param.grad.data.view(-1).add_(dest_buffer)
else:
dest_buffer.narrow(0,
source_offset,
num_elements).copy_(
self.accumulated_grads_in_cpu[param_id].view(-1),
non_blocking=True)
param.grad.data.view(-1).narrow(
0,
source_offset,
num_elements).add_(dest_buffer.narrow(0,
source_offset,
num_elements))
# move accumulated gradients back to CPU
def copy_gradients_to_cpu():
if not self.fp16_master_weights_and_gradients:
self.accumulated_grads_in_cpu[param_id].data.copy_(
param.grad.data.view(-1),
non_blocking=True)
else:
self.accumulated_grads_in_cpu[param_id].data.copy_(
param.grad.data.view(-1).narrow(0,
source_offset,
num_elements),
non_blocking=True)
if param_id not in self.accumulated_grads_in_cpu:
self.accumulated_grads_in_cpu[param_id] = buffer_to_accumulate_to_in_cpu(
)
if self.micro_step_id > 0:
accumulate_gradients()
# at the boundary we will send 32bit directly
if not self.is_gradient_accumulation_boundary:
copy_gradients_to_cpu()
def set_norm_for_param_grad(self, param):
param_id = self.get_param_id(param)
accumulated_grad = self.accumulated_grads_in_cpu[
param_id] if self.gradient_accumulation_steps > 1 else param.grad
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
start = source_offset
accumulated_grad = accumulated_grad.view(
-1).narrow(0, start, num_elements)
self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(
2)
def set_norm_for_param_grad_in_gpu(self, param):
param_id = self.get_param_id(param)
accumulated_grad = param.grad
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
start = source_offset
accumulated_grad = accumulated_grad.view(
-1).narrow(0, start, num_elements)
self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(
2)
def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):
param_id = self.get_param_id(param)
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(
0,
dest_offset,
num_elements)
src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements)
if not self.fp16_master_weights_and_gradients:
src_tensor = src_tensor.float()
dest_tensor.copy_(src_tensor, non_blocking=True)
param.grad = None # offload only
def complete_grad_norm_calculation_for_cpu_offload(self, params):
total_norm = 0.0
norm_type = 2.0
for p in params:
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_id = self.get_param_id(p)
# as some model have trainable parameters but skipped in training,
# their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run,
# so they have no norm_for_param_grads
if param_id in self.norm_for_param_grads:
param_norm = self.norm_for_param_grads[param_id]
total_norm += param_norm.item() ** 2
else:
# As unused parameters in modules may not be expected sometimes,
# add an explicit error msg when it occurred and an option to
# avoid the error
assert self.ignore_unused_parameters, """
This assert indicates that your module has parameters that
were not used in producing loss.
You can avoid this assert by
(1) enable ignore_unused_parameters option in zero_optimization config;
(2) making sure all trainable parameters and `forward` function
outputs participate in calculating loss.
"""
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
if total_norm == float(
'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
############################################################################################
def copy_grads_in_partition(self, param):
if self.cpu_offload:
if self.gradient_accumulation_steps > 1:
self.async_accumulate_grad_in_cpu_via_gpu(param)
if self.is_gradient_accumulation_boundary:
self.set_norm_for_param_grad_in_gpu(param)
self.update_overflow_tracker_for_param_grad(param)
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
return
# print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
if self.grads_in_partition is None:
self.grads_in_partition_offset = 0
total_size = 0
for group in self.params_in_partition:
for param_in_partition in group:
total_size += param_in_partition.numel()
if self.verbose:
report_memory_usage(
f"before copying {total_size} gradients into partition")
self.grads_in_partition = torch.empty(int(total_size),
dtype=self.dtype,
device=torch.cuda.current_device())
if self.verbose:
report_memory_usage(
f"after copying {total_size} gradients into partition")
# The allreduce buffer will be rewritted. Copy the gradients in partition to a new buffer
new_grad_tensor = self.grads_in_partition.view(-1).narrow(
0,
self.grads_in_partition_offset,
param.numel())
new_grad_tensor.copy_(param.grad.view(-1))
param.grad.data = new_grad_tensor.data.view_as(param.grad)
# print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
self.grads_in_partition_offset += param.numel()
def reduce_ipg_grads(self):
if self.contiguous_gradients:
if self.extra_large_param_to_reduce is not None:
assert len(
self.params_in_ipg_bucket) == 1, "more than 1 param in ipg bucket, this shouldn't happen"
_, _, param_id = self.params_in_ipg_bucket[0]
assert self.get_param_id(
self.extra_large_param_to_reduce) == param_id, "param in ipg bucket does not match extra-large param"
self.average_tensor(
self.extra_large_param_to_reduce.grad.view(-1))
self.extra_large_param_to_reduce = None
else:
self.average_tensor(self.ipg_buffer[self.ipg_index])
else:
self.buffered_reduce_fallback(
None,
self.grads_in_ipg_bucket,
elements_per_buffer=self.elements_in_ipg_bucket)
if self.overlap_comm:
stream = self.reduction_stream
elif self.cpu_offload:
# TODO: copy_grad_stream is disabled because of race with reduce. This hurts perf and should be fixed.
# torch.cuda.synchronize()
# stream = self.copy_grad_stream
stream = torch.cuda.current_stream()
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
for _, param, param_id in self.params_in_ipg_bucket:
assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Gradient computed twice for this partition. \
Multiple gradient reduction is currently not supported"
self.params_already_reduced[param_id] = True
if self.partition_gradients:
if not self.is_param_in_current_partition[param_id]:
if self.overlap_comm and self.contiguous_gradients is False:
# Clear grads of other partitions during the next reduction
# to avoid clearing them before the reduction is complete.
if self.previous_reduced_grads is None:
self.previous_reduced_grads = []
self.previous_reduced_grads.append(param)
else:
param.grad = None # only if self.partition_gradients
elif self.contiguous_gradients:
self.copy_grads_in_partition(param)
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.ipg_bucket_has_moe_params = False
self.elements_in_ipg_bucket = 0
#####################################################################
def reduce_ready_partitions_and_remove_grads(self, param, i):
if self.partition_gradients or self.is_gradient_accumulation_boundary:
self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
def zero_reduced_gradients(self, partition_id, i):
def are_all_related_partitions_reduced(params_id):
for partition_id in self.param_to_partition_ids[i][params_id]:
if not self.is_partition_reduced[i][partition_id]:
return False
return True
for params_id in self.is_grad_computed[i][partition_id]:
if are_all_related_partitions_reduced(params_id):
self.param_dict[params_id].grad = None # dead code
def flatten_and_print(self, message, tensors, start=0, n=5):
flatten_tensor = self.flatten(tensors)
def print_func():
print(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
self.sequential_execution(print_func, message)
def get_grads_to_reduce(self, i, partition_id):
def get_reducable_portion(key):
grad = self.param_dict[key].grad
total_elements = grad.numel()
start = self.grad_start_offset[i][partition_id][key]
num_elements = min(
total_elements - start,
self.partition_size[i] -
self.grad_partition_insertion_offset[i][partition_id][key])
if not pg_correctness_test:
if num_elements == total_elements:
return grad
else:
return grad.contiguous().view(-1).narrow(0,
int(start),
int(num_elements))
else:
if num_elements == total_elements:
return grad.clone()
else:
return grad.clone().contiguous().view(-1).narrow(
0,
int(start),
int(num_elements))
grads_to_reduce = []
for key in self.is_grad_computed[i][partition_id]:
grad = get_reducable_portion(key)
grads_to_reduce.append(grad)
return grads_to_reduce
def sequential_execution(self, function, message, group=None):
if group is None:
group = self.dp_process_group
if dist.get_rank(group=group) == 0:
print(message)
for id in range(dist.get_world_size(group=group)):
if id == dist.get_rank(group=group):
function()
dist.barrier(group=group)
def set_none_gradients_to_zero(self, i, partition_id):
for param_id in self.is_grad_computed[i][partition_id]:
param = self.param_dict[param_id]
if param.grad is None:
param.grad = torch.zero_like(param)
######################Reduction Related Methods##############################
def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
rank = None
tensor = self.flatten(bucket)
tensor_to_allreduce = tensor
if pg_correctness_test:
allreduce_always_fp32 = True
if allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
tensor_to_allreduce.div_(
dist.get_world_size(group=self.dp_process_group))
if rank is None:
# "All Reducing"
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
else:
global_rank = _get_global_rank(self.dp_process_group, rank)
dist.reduce(tensor_to_allreduce, global_rank,
group=self.dp_process_group)
if allreduce_always_fp32 and tensor is not tensor_to_allreduce:
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
tensor.copy_(tensor_to_allreduce)
return tensor
def _clear_previous_reduced_grads(self):
if self.previous_reduced_grads is not None:
for param in self.previous_reduced_grads:
param.grad = None # overlap enabled
self.previous_reduced_grads = None
# if rank is specified do a reduction instead of an allreduce
def allreduce_and_copy(self, small_bucket, rank=None, log=None):
if self.overlap_comm:
torch.cuda.synchronize()
# It is safe to clear the previously reduced grads of other partitions
self._clear_previous_reduced_grads()
stream = self.reduction_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
allreduced = self.allreduce_bucket(
small_bucket, rank=rank, log=log)
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
buf.copy_(synced)
def allreduce_no_retain(self,
bucket,
numel_per_bucket=500000000,
rank=None,
log=None):
small_bucket = []
numel = 0
for tensor in bucket:
small_bucket.append(tensor)
numel = numel + tensor.numel()
if numel > numel_per_bucket:
self.allreduce_and_copy(small_bucket, rank=rank, log=None)
small_bucket = []
if len(small_bucket) > 0:
self.allreduce_and_copy(small_bucket, rank=rank, log=log)
# allows using reduction of gradients instead of using all_reduce
def buffered_reduce_fallback(self,
rank,
grads,
elements_per_buffer=500000000,
log=None):
split_buckets = split_half_float_double(grads)
for i, bucket in enumerate(split_buckets):
self.allreduce_no_retain(bucket,
numel_per_bucket=elements_per_buffer,
rank=rank,
log=log)
#############################################################################
#############################################################################
#############################################################################
# views the tensor as multiple partitions and returns
# those partitions
def get_data_parallel_partitions(self, tensor, group_id):
partitions = []
dp = dist.get_world_size(group=self.real_dp_process_group[group_id])
dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])
total_num_elements = tensor.numel()
base_size = total_num_elements // dp
remaining = total_num_elements % dp
start = 0
for id in range(dp):
partition_size = base_size
if id < remaining:
partition_size = partition_size + 1
partitions.append(tensor.narrow(0, start, partition_size))
start = start + partition_size
return partitions
def get_partition_info(self, tensor_list, partition_size, partition_id):
params_in_partition = []
params_not_in_partition = []
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for tensor in tensor_list:
tensor_size = tensor.numel()
if (current_index >= start_index and current_index < end_index):
params_in_partition.append(tensor)
elif start_index > current_index and start_index < (current_index +
tensor_size):
params_in_partition.append(tensor)
assert (
first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
else:
params_not_in_partition.append(tensor)
current_index = current_index + tensor_size
return params_in_partition, params_not_in_partition, first_offset
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None # epilogue and in step
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def _model_parallel_all_reduce(self, tensor, op):
""" Perform all reduce within model parallel group, if any.
"""
if self.model_parallel_group is None:
pass
else:
torch.distributed.all_reduce(tensor=tensor,
op=op,
group=self.model_parallel_group)
def clip_grad_norm(self, *args, **kwargs):
# dummy function to retain the same function interface
# as ColossalaiOptimizer for compatibility
pass
def get_grad_norm_direct(self, gradients, params, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from ``torch.nn.utils.clip_grad.clip_grad_norm_`` and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=self.dp_process_group)
# Take max across all GPUs.
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
# if dist.get_rank() == 0:
# print()(f"Total Norm begining {total_norm}")
for g, p in zip(gradients, params):
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_norm = g.data.double().norm(2)
total_norm += param_norm.item() ** 2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
if total_norm == float(
'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
# creates a flat fused tensor from the tensor list starting at the first_offset
# in the first tensor of the list. If there are not enough elements in the tensor
# list then the flat tensor will be padded with zeros
def get_flat_partition(self,
tensor_list,
first_offset,
partition_size,
dtype,
device,
return_tensor_list=False):
flat_tensor_list = []
current_size = 0
for i, tensor in enumerate(tensor_list):
if tensor.grad is None:
tensor.grad = torch.zeros_like(tensor)
tensor = tensor.grad
num_elements = tensor.numel()
tensor_offset = 0
# we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
# we dont need all elements of the tensor
if num_elements > (partition_size - current_size):
num_elements = partition_size - current_size
# we need a narrow view of the tensor based on the tensor offset and number of elements that
# we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0,
int(tensor_offset),
int(num_elements)))
else:
flat_tensor_list.append(tensor)
current_size = current_size + num_elements
# this means its the last partition and does not align with the dp boundary. We need to pad before flattening
if current_size < partition_size:
flat_tensor_list.append(
torch.zeros(int(partition_size - current_size),
dtype=dtype,
device=device))
if return_tensor_list:
return flat_tensor_list
return self.flatten(flat_tensor_list)
def free_grad_in_param_list(self, param_list):
for p in param_list:
p.grad = None # in step
def reset_cpu_buffers(self):
self.norm_for_param_grads = {}
self.local_overflow = False
def log_timers(self, timer_names):
if self.timers is None:
return
self.timers.log(names=list(timer_names))
def start_timers(self, timer_names):
if self.timers is None:
return
for name in timer_names:
self.timers(name).start()
def stop_timers(self, timer_names):
if self.timers is None:
return
for name in timer_names:
self.timers(name).stop()
def step(self, closure=None):
"""
Not supporting closure.
"""
self.micro_step_id = -1
if self.verbose:
report_memory_usage(f"In step before checking overflow")
# First compute norm for all group so we know if there is overflow
self.check_overflow(self.partition_gradients)
OPTIMIZER_ALLGATHER = 'optimizer_allgather'
OPTIMIZER_GRADIENTS = 'optimizer_gradients'
OPTIMIZER_STEP = 'optimizer_step'
timer_names = [OPTIMIZER_ALLGATHER,
OPTIMIZER_GRADIENTS, OPTIMIZER_STEP]
prev_scale = self.loss_scale
self._update_scale(self.overflow)
if self.overflow:
if self.verbose:
report_memory_usage('After overflow before clearing gradients')
self.zero_grad()
if self.cpu_offload:
self.reset_cpu_buffers()
else:
self.averaged_gradients = {}
if self.verbose:
report_memory_usage('After overflow after clearing gradients')
print(
"[deepspeed] fp16 dynamic loss scale overflow! Rank {} Skipping step. Attempted loss scale: {}, "
"reducing to {}".format(dist.get_rank(),
prev_scale,
self.loss_scale))
self.start_timers(timer_names)
self.stop_timers(timer_names)
return
self.start_timers([OPTIMIZER_GRADIENTS])
norm_groups = []
single_partition_grad_groups = []
skip = False
for i, group in enumerate(self.fp16_groups):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
if self.cpu_offload:
norm_groups.append(
self.complete_grad_norm_calculation_for_cpu_offload(
self.params_in_partition[i]))
single_grad_partition = self.single_partition_of_fp32_groups[i].grad
else:
norm_groups.append(
self.get_grad_norm_direct(self.averaged_gradients[i],
self.params_in_partition[i]))
# free gradients for all the prameters that are not updated by this process
self.free_grad_in_param_list(self.params_not_in_partition[i])
# create a flat gradients for parameters updated by this process
# If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
if partition_id == dist.get_world_size(
group=self.real_dp_process_group[i]) - 1:
single_grad_partition = self.flatten_dense_tensors_aligned(
self.averaged_gradients[i],
int(self.partition_size[i])).to(
self.single_partition_of_fp32_groups[i].dtype)
else:
single_grad_partition = self.flatten(self.averaged_gradients[i]).to(
self.single_partition_of_fp32_groups[i].dtype)
assert single_grad_partition.numel() == self.partition_size[i], \
"averaged gradients have different number of elements that partition size {} {} {} {}".format(
single_grad_partition.numel(), self.partition_size[i], i, partition_id)
self.single_partition_of_fp32_groups[i].grad = single_grad_partition
# release all the gradient since we have already created a necessary copy in dp_grad_partition
self.free_grad_in_param_list(self.params_in_partition[i])
self.averaged_gradients[i] = None
single_partition_grad_groups.append(single_grad_partition)
if self.has_moe_layers:
self._average_expert_grad_norms(norm_groups)
self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
self.stop_timers([OPTIMIZER_GRADIENTS])
self.start_timers([OPTIMIZER_STEP])
if self.deepspeed_adam_offload:
from deepspeed.ops.adam import DeepSpeedCPUAdam
if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
fp16_param_groups = [
fp16_partitions[partition_id]
for fp16_partitions in self.parallel_partitioned_fp16_groups
]
self.optimizer.step(fp16_param_groups=fp16_param_groups)
else:
self.optimizer.step()
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups,
self.single_partition_of_fp32_groups):
fp16_partitions[partition_id].data.copy_(
fp32_partition.data)
else:
self.optimizer.step()
# get rid of the fp32 gradients. Not needed anymore
if not self.cpu_offload:
for group in self.single_partition_of_fp32_groups:
group.grad = None # in step
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups,
self.single_partition_of_fp32_groups):
fp16_partitions[partition_id].data.copy_(fp32_partition.data)
self.stop_timers([OPTIMIZER_STEP])
if self.cpu_offload:
self.reset_cpu_buffers()
self.start_timers([OPTIMIZER_ALLGATHER])
# gather the updated weights from everyone
for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):
# Sequential AllGather Best of both worlds
dp_world_size = dist.get_world_size(
group=self.real_dp_process_group[group_id])
num_shards = max(
1,
partitioned_params[partition_id].numel() * dp_world_size //
self.allgather_bucket_size)
shard_size = partitioned_params[partition_id].numel() // num_shards
num_elements = shard_size
assert shard_size * \
num_shards <= partitioned_params[partition_id].numel()
for shard_id in range(num_shards):
if shard_id == (num_shards - 1):
num_elements = partitioned_params[partition_id].numel(
) - shard_id * shard_size
shard_list = []
for dp_id in range(dp_world_size):
curr_shard = partitioned_params[dp_id].narrow(
0,
shard_id * shard_size,
num_elements).detach()
shard_list.append(curr_shard)
dist.all_gather(shard_list,
shard_list[partition_id],
group=self.real_dp_process_group[group_id])
self.stop_timers([OPTIMIZER_ALLGATHER])
# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
self._update_model_fp16_weights(i)
self.log_timers(timer_names)
if self.verbose:
report_memory_usage('After zero_optimizer step')
return
def _average_expert_grad_norms(self, norm_groups):
for i, norm in enumerate(norm_groups):
if self.is_moe_param_group[i]:
scaled_norm = norm * 1.0 / float(
dist.get_world_size(group=self.ep_process_group))
scaled_norm_tensor = torch.tensor(scaled_norm,
device='cuda',
dtype=torch.float)
dist.all_reduce(scaled_norm_tensor,
group=self.ep_process_group)
norm_groups[i] = scaled_norm_tensor.item()
def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm ** 2.0
total_norm = math.sqrt(total_norm)
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
if clip > 1:
combined_scale = clip * self.loss_scale
for grad in grad_groups_flat:
if isinstance(grad, list):
sub_partitions = grad
for g in sub_partitions:
g.data.mul_(1. / combined_scale)
else:
grad.data.mul_(1. / combined_scale)
def _check_overflow(self, partition_gradients=True):
self.overflow = self.has_overflow(partition_gradients)
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params, is_grad_list=False):
for p in params:
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
return True
return False
def has_overflow_partitioned_grads_serial(self):
for i in range(len(self.fp16_groups)):
for j, grad in enumerate(self.averaged_gradients[i]):
if grad is not None and self._has_inf_or_nan(grad.data, j):
return True
return False
def has_overflow(self, partition_gradients=True):
if partition_gradients:
overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial(
)
overflow_gpu = torch.cuda.ByteTensor([overflow])
'''This will capture overflow across all data parallel and expert parallel process
Since expert parallel process are a subset of data parallel process'''
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=self.dp_process_group)
else:
params = []
for group in self.fp16_groups:
for param in group:
params.append(param)
overflow = self.has_overflow_serial(
params, is_grad_list=partition_gradients)
overflow_gpu = torch.cuda.ByteTensor([overflow])
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
self._model_parallel_all_reduce(tensor=overflow_gpu,
op=torch.distributed.ReduceOp.MAX)
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x, j=None):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
# We want to check if inst is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False
def backward(self, loss, retain_graph=False):
"""
:attr:`backward` performs the following steps:
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
self.micro_step_id += 1
if self.contiguous_gradients:
self.ipg_buffer = []
buf_0 = torch.empty(int(self.reduce_bucket_size),
dtype=self.dtype,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_0)
# Use double buffers to avoid data access conflict when overlap_comm is enabled.
if self.overlap_comm:
buf_1 = torch.empty(int(self.reduce_bucket_size),
dtype=self.dtype,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_1)
self.ipg_index = 0
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
def check_overflow(self, partition_gradients=True):
self._check_overflow(partition_gradients)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
# Return group tensor after removing paddings that are added for alignment to DP world size.
# This method works on the assumption that each group contains a single flattened tensor.
def _get_groups_without_padding(self, groups_with_padding):
groups_without_padding = []
for i, group in enumerate(groups_with_padding):
lean_length = group.numel() - self.groups_padding[i]
groups_without_padding.append(group[:lean_length])
return groups_without_padding
# Return optimizer state after removing paddings that are added for alignment.
def _get_state_without_padding(self, state_with_padding, padding):
lean_state = {}
for key, value in state_with_padding.items():
if torch.is_tensor(value):
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]
else:
lean_state[key] = value
return lean_state
# Return base optimizer states.
# This method assumes that each param group contains a single flattened tensor.
def _get_base_optimizer_state(self):
optimizer_groups_state = []
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
lean_optimizer_state = self._get_state_without_padding(
self.optimizer.state[p],
self.groups_padding[i])
optimizer_groups_state.append(lean_optimizer_state)
return optimizer_groups_state
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['base_optimizer_state'] = self._get_base_optimizer_state()
state_dict['zero_stage'] = ZERO_OPTIMIZATION_GRADIENTS
state_dict['partition_count'] = self.partition_count
state_dict['ds_version'] = version
# Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding = self._get_groups_without_padding(
self.single_partition_of_fp32_groups)
state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding
# if self.cpu_offload:
# state_dict_tmp = async_copy_to(state_dict,
# 'cpu',
# torch.cuda.current_stream())
# state_dict = state_dict_tmp
return state_dict
# Restore base optimizer fp32 weights from checkpoint by:
# 1) Merging fp32 weights from checkpoints of all partitions
# 2) Extracting fp32 weights for current partition from merged weights
# 3) Using extracted weights to update base optimizer weights directly.
def _restore_from_fp32_weights(self, all_state_dict):
merged_single_partition_of_fp32_groups = []
for i in range(len(self.single_partition_of_fp32_groups)):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
merged_partitions = [
sd['single_partition_of_fp32_groups'][i] for sd in all_state_dict
]
flat_merged_partitions = self.flatten_dense_tensors_aligned(
merged_partitions,
self.nccl_start_alignment_factor *
dist.get_world_size(group=self.real_dp_process_group[i]))
dp_partitions = self.get_data_parallel_partitions(
flat_merged_partitions, i)
merged_single_partition_of_fp32_groups.append(
dp_partitions[partition_id])
for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups):
current.data.copy_(saved.data)
# Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self):
for group_id, fp16_partitions, fp32_partition in enumerate(
zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups)):
partition_id = dist.get_rank(
group=self.real_dp_process_group[group_id])
fp32_partition.data.copy_(fp16_partitions[partition_id].data)
# Refresh the fp32 master params from the fp16 copies.
def refresh_fp32_params(self):
self._restore_from_fp16_weights()
# Extract optimizer state for current partition from merged states of all partitions
def _partition_base_optimizer_state(self, state_key, all_partition_states, group_id):
partition_id = dist.get_rank(
group=self.real_dp_process_group[group_id])
alignment = dist.get_world_size(
group=self.real_dp_process_group[group_id])
if torch.is_tensor(all_partition_states[0]):
flat_merged_partitions = self.flatten_dense_tensors_aligned(
all_partition_states,
alignment)
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions,
group_id)
return dp_partitions[partition_id]
else:
# Assume non-tensor states are not partitioned and equal across ranks, so return first one
return all_partition_states[0]
# Restore base optimizer state from checkpoint by
# 1) Merging optimizer state from checkpoints of all partitions
# 2) Extracting optimizer state for current partition from the merged state
# 3) Using the extracted value to directly update the base optimizer.
def _restore_base_optimizer_state(self, all_state_dict):
base_optimizer_group_states = []
for i in range(len(self.optimizer.param_groups)):
partition_states = {}
all_partition_group_states = [
sd['base_optimizer_state'][i] for sd in all_state_dict
]
for key in all_partition_group_states[0].keys():
all_partition_states = [
all_states[key] for all_states in all_partition_group_states
]
partition_states[key] = self._partition_base_optimizer_state(
key,
all_partition_states,
i)
base_optimizer_group_states.append(partition_states)
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
for key, saved in base_optimizer_group_states[i].items():
if torch.is_tensor(self.optimizer.state[p][key]):
self.optimizer.state[p][key].data.copy_(saved.data)
else:
self.optimizer.state[p][key] = saved
def load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
r"""Loading ZeRO checkpoint
Arguments:
state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
Note that the number of saved partitions may differ from number of loading partitions to support
changing GPU count, specifically DP world size, between saving and loading checkpoints.
load_optimizer_states: Boolean indicating whether or not to load base optimizer states
load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32
copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss).
"""
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict_list[0]['loss_scaler']
self.dynamic_loss_scale = state_dict_list[0]['dynamic_loss_scale']
self.overflow = state_dict_list[0]['overflow']
# zero stage 1 mode
if not self.partition_gradients:
required_version = pkg_version.parse("0.3.17")
ckpt_version = state_dict_list[0].get("ds_version", False)
error_str = f"ZeRO stage 1 changed in {required_version} and is not backwards compatible " \
"with older stage 1 checkpoints. If you'd like to load an old ZeRO-1 checkpoint " \
"please set 'legacy_stage1': true in your zero config json. This old version of " \
"stage 1 will be removed in v0.4.0."
assert ckpt_version, f"Empty ds_version! {error_str}"
assert required_version <= pkg_version.parse(
ckpt_version), f"Old version: {ckpt_version} {error_str}"
if load_optimizer_states:
self._restore_base_optimizer_state(state_dict_list)
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 1 if changing DP degree and option 2 otherwise.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
if load_from_fp32_weights:
self._restore_from_fp32_weights(state_dict_list)
else:
self._restore_from_fp16_weights()
def allreduce_gradients(self):
self.overlapping_partition_gradients_reduce_epilogue()
def _handle_overflow(cpu_sum, x, i):
import math
rank = torch.distributed.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
if not math.isfinite(float(v)):
t_i = v_i
break
print(
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
)
def estimate_zero2_model_states_mem_needs(total_params,
num_gpus_per_node=1,
num_nodes=1,
cpu_offload=True,
additional_buffer_factor=1.5):
total_gpus = num_nodes * num_gpus_per_node
if cpu_offload:
gpu_mem = 2 * total_params
cpu_mem = total_params * \
max(4 * total_gpus, 16) * additional_buffer_factor
else:
gpu_mem = 4 * total_params + int(16 * total_params / total_gpus)
cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor
return int(cpu_mem), int(gpu_mem)
def model_to_params(model):
# shared params calculated only once
total_params = sum(
dict((p.data_ptr(),
p.numel()) for p in model.parameters()).values())
return total_params
def estimate_zero2_model_states_mem_needs_all_live(model,
num_gpus_per_node=1,
num_nodes=1,
additional_buffer_factor=1.5):
"""
Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients
for a given ``model`` and hardware setup.
If you have an actual model object, use this function and everything will be derived
automatically.
If it's a hypothetical model, use ``estimate_zero2_model_states_mem_needs_all_cold`` where you have to pass
the ``total_params`` explicitly.
Args:
- ``model``: ``nn.Module`` object
- ``num_gpus_per_node``: how many gpus per node (defaults to 1)
- ``num_nodes``: how many nodes (defaults to 1),
- ``additional_buffer_factor``: estimation factor (defaults to 1.5):
"""
total_params = model_to_params(model)
estimate_zero2_model_states_mem_needs_all_cold(
total_params=total_params,
num_gpus_per_node=num_gpus_per_node,
num_nodes=num_nodes,
additional_buffer_factor=additional_buffer_factor)
def estimate_zero2_model_states_mem_needs_all_cold(total_params,
num_gpus_per_node=1,
num_nodes=1,
additional_buffer_factor=1.5):
"""
Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients
for a given ``model`` and hardware setup.
If it's a hypothetical model, use this function where you have to pass
the ``total_params`` and ``largest_layer_params`` explicitly.
If you have an actual model object, use ``estimate_zero2_model_states_mem_needs_all_live`` and everything
will be derived automatically.
Args:
- ``total_params``: total model params
- ``num_gpus_per_node``: how many gpus per node (defaults to 1)
- ``num_nodes``: how many nodes (defaults to 1),
- ``additional_buffer_factor``: estimation factor (defaults to 1.5):
"""
def format_options(cpu_offload):
enabled = []
enabled.append(f"cpu_offload={1 if cpu_offload else 0}")
return ", ".join(enabled)
nodes_str = "nodes" if num_nodes > 1 else "node"
gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU"
print(
"Estimated memory needed for params, optim states and gradients for a:\n"
f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n"
f"SW: Model with {int(total_params / 1e6)}M total params.")
print(" per CPU | per GPU | Options")
for cpu_offload in [True, False]:
cpu_mem, gpu_mem = estimate_zero2_model_states_mem_needs(
total_params=total_params,
num_gpus_per_node=num_gpus_per_node,
num_nodes=num_nodes,
cpu_offload=cpu_offload,
additional_buffer_factor=additional_buffer_factor
)
options_str = format_options(cpu_offload=cpu_offload)
print(
f" {cpu_mem / 2 ** 30:7.2f}GB | {gpu_mem / 2 ** 30:6.2f}GB | {options_str}")
"""
"Copyright 2020 The Microsoft DeepSpeed Team.
Licensed under the MIT license.
"""
import math
from collections import OrderedDict
import torch
import torch.distributed as dist
try:
from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id, debug_param2name_id_numel, \
debug_param2name_id_shape_device, debug_module2name_class
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.op_builder import UtilsBuilder
from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper
from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper
from deepspeed.runtime.utils import is_model_parallel_parameter
from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.partition_parameters import _init_external_params
except ImportError:
pass
from torch._six import inf
from torch.distributed.distributed_c10d import _get_global_rank
from torch.optim import Optimizer
from colossalai.core import global_context as gpc
from colossalai.utils import report_memory_usage
from .loss_scaler import LossScaler, DynamicLossScaler
from colossalai.context import ParallelMode
# Toggle this to true to enable correctness test
# with gradient partitioning and without
pg_correctness_test = False
FWD_MODULE_STACK = list()
def print_rank_0(message, debug=False, force=False):
rank = torch.distributed.get_rank()
if rank == 0 and (debug or force):
print(message)
# other variations
# - print for all ranks w/o interleaving
# printflock(f"[{rank}] {message}")
# - print to log file per rank
# log_rank_file(rank, message)
def input(msg):
return
def split_half_float_double(tensors):
dtypes = [
"torch.cuda.HalfTensor",
"torch.cuda.FloatTensor",
"torch.cuda.DoubleTensor"
]
buckets = []
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append(bucket)
return buckets
def isclose(a, b, rtol=1e-09, atol=0.0):
return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)
def lcm(x, y):
from fractions import gcd # or can import gcd from `math` in Python 3
return x * y // gcd(x, y)
def move_to_cpu(tensor_list):
for tensor in tensor_list:
tensor.data = tensor.data.cpu()
def get_all_parameters(sub_module, recurse=False):
return itertools.chain(sub_module.named_parameters(recurse=recurse),
sub_module.ds_external_parameters())
# apply torch.autograd.Function that calls a backward_function to tensors in output
def _apply_to_tensors_only(module, functional, backward_function, outputs):
if type(outputs) is tuple:
touched_outputs = []
for output in outputs:
touched_output = _apply_to_tensors_only(module,
functional,
backward_function,
output)
touched_outputs.append(touched_output)
return tuple(touched_outputs)
elif type(outputs) is torch.Tensor:
return functional.apply(module, backward_function, outputs)
else:
return outputs
# for each tensor in outputs run the forward_funciton and register backward_function as hook
def _apply_forward_and_backward_to_tensors_only(module,
forward_function,
backward_function,
outputs):
if type(outputs) is tuple:
touched_outputs = []
for output in outputs:
touched_output = _apply_forward_and_backward_to_tensors_only(
module,
forward_function,
backward_function,
output)
touched_outputs.append(touched_output)
return tuple(touched_outputs)
elif type(outputs) is torch.Tensor:
forward_function(outputs)
if outputs.requires_grad:
outputs.register_hook(backward_function)
return outputs
else:
return outputs
class ZeROOrderedDict(OrderedDict):
def __init__(self, parent_module, *args, **kwargs):
"""A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
Args:
parent_module (``collections.OrderedDict``): the collection to replace
"""
super().__init__(*args, **kwargs)
self._parent_module = parent_module
self._in_forward = False
def __getitem__(self, key):
param = super().__getitem__(key)
# Params can be registered as None (e.g., bias)
if param is None:
return param
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if self._parent_module._parameters._in_forward:
print_rank_0(f'Registering external parameter from getter {key}',
force=False)
register_external_parameter(FWD_MODULE_STACK[-1], param)
param.all_gather()
return param
def _inject_parameters(module, cls):
for module in module.modules():
if cls == ZeROOrderedDict:
new_param = cls(parent_module=module)
else:
new_param = cls()
for key, param in module._parameters.items():
new_param[key] = param
module._parameters = new_param
# TODO Needs to be implemented
class PrefetchCoordinator(object):
def __init__(self):
# step_id keeps track of the number of sub-modules invoked so far
# the step_id is tracking forward and backward sequence of sub-modules
self.step_id = 0
# stores the sequence of sub modules in forward+backward pass
self.sub_module_trace = []
# maps sub_module id to submodule objects
self.id_to_sub_module_map = {}
# stores the total number of parmeters in each sub_module
self.id_to_sub_module_size_map = {}
self.trace_completed = False
self.most_recent_sub_module_step = {}
# reuse distances
self.reuse_numel_for_step_id = {}
def record_trace(self, sub_module):
if not self.trace_completed:
self.sub_module_trace.append(sub_module.id)
self.id_to_sub_module_map[sub_module.id] = sub_module
def print_trace(self):
print_rank_0(
f"The module trace is : {[self.id_to_sub_module_map[module_id].id for module_id in self.sub_module_trace]}"
)
def increment_step(self, sub_module):
self.most_recent_sub_module_step[sub_module.id] = self.step_id
self.step_id += 1
def reset_step(self):
self.step_id = 0
# returns the next numel parameters that will be used next but are not available or inflight
def get_params_to_prefetch(self, sub_module, numel=2000000):
# numel_in_sub_module = 0
# for name, param in sub_module.named_parameters(recurse=False):
# numel_in_sub_module += param.ds_numel
# #if numel_in_sub_module < (numel // 2):
# return []
# tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing
if sub_module.id != self.sub_module_trace[self.step_id]:
print_rank_0(
f"Tracing failed. Prefetching is disabled at sub-module: {debug_module2name_id(sub_module)}"
)
return []
params_to_prefetch = []
total_numel_to_prefetch = 0
for i in range(self.step_id, len(self.sub_module_trace)):
module_id = self.sub_module_trace[i]
for _, param in get_all_parameters(self.id_to_sub_module_map[module_id]):
if param.ds_status is ZeroParamStatus.NOT_AVAILABLE and (
param.ds_id not in [p.ds_id for p in params_to_prefetch]):
params_to_prefetch.append(param)
total_numel_to_prefetch += param.ds_numel
# print_rank_0(f"Total numel to prefetch: {total_numel_to_prefetch}. Param: {param.ds_shape} and numel {param.ds_numel}, numel limit {numel}")
# and total_numel_to_prefetch > (numel_in_sub_module // 2):
if total_numel_to_prefetch >= numel:
return params_to_prefetch
return params_to_prefetch
# checks if this sub_module will be used again and if so then returns the number of elements
# in the parameters used between this sub_module and the reuse of this sub_module
def get_reuse_distance_in_numel(self, sub_module, sub_module_step_id=None):
# assert is_forward is not None, "is_forward must be set to True for Forward Propagation and False for backward Propagation"
is_there_reuse = False
reuse_distance_in_numel = 1000000000000
# set the appropriate trace
trace = self.sub_module_trace
total_steps = len(trace)
if sub_module_step_id is None:
sub_module_step_id = self.most_recent_sub_module_step[sub_module.id]
# tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing
if sub_module.id != trace[sub_module_step_id]:
print_rank_0(
f"Tracing failed. Cannot tell if the sub_module: {sub_module.id} is reused"
)
return reuse_distance_in_numel
# return cached value
if sub_module_step_id in self.reuse_numel_for_step_id:
return self.reuse_numel_for_step_id[sub_module_step_id]
start_step = self.step_id
print_rank_0(f"Step id is {self.step_id} ")
for step_id in range(start_step, total_steps):
print_rank_0(
f"Trace id {trace[step_id]} and sub_module id {sub_module.id}")
if sub_module.id == trace[step_id]:
end_step = step_id
is_there_reuse = True
reuse_distance_in_numel = self._distance_in_numel(
start_step,
end_step,
trace)
break
self.reuse_numel_for_step_id[sub_module_step_id] = reuse_distance_in_numel
return reuse_distance_in_numel
def _distance_in_numel(self, start_step, end_step, trace):
distance_in_numel = 0
for step_id in range(start_step, end_step):
module_id = trace[step_id]
for _, param in self.id_to_sub_module_map[module_id].named_parameters(recurse=False):
distance_in_numel += param.ds_numel
for _, param in self.id_to_sub_module_map[module_id].ds_external_parameters():
distance_in_numel += param.ds_numel
return distance_in_numel
class PartitionedParameterCoordinator(object):
def __init__(self,
comm_stream=None,
max_reuse_distance_in_numel=500000000,
max_available_parameters_in_numel=700000000):
self.in_flight_handles = []
self.params_in_flight = []
self.comm_stream = comm_stream if comm_stream is not None else torch.cuda.current_stream(
)
self.prefetch_coordinator = PrefetchCoordinator()
self.hierarchy = 0
self.total_available_parameter_numel = 0
self.max_available_parameters_in_numel = max_available_parameters_in_numel
# max distance between two use of the module beyond which module is released
self.max_reuse_distance_in_numel = max_reuse_distance_in_numel
def _increment_available_parameter_numel(self, increment):
self.total_available_parameter_numel += increment
def _decrement_available_parameter_numel(self, decrement):
self.total_available_parameter_numel -= decrement
'''-----------------------Tracing and Prefetching ---------------'''
def record_trace(self, sub_module):
self.prefetch_coordinator.record_trace(sub_module)
def finish_tracing(self, print_trace=False):
self.prefetch_coordinator.trace_completed = True
if print_trace:
self.prefetch_coordinator.print_trace()
# swap in parameter partitions from nvme for those parameters that will be used
# after the ones that are already being prefetched into full parameters
def _prefetch_nvme_param_partitions(self, sub_module, params_in_flight):
numel_in_flight = sum(
[param.ds_tensor.ds_numel for param in params_in_flight])
upcoming_param_list = self.prefetch_coordinator.get_params_to_prefetch(
sub_module,
numel=2 * numel_in_flight)
swap_in_params = []
for param in upcoming_param_list:
if len(swap_in_params) >= param.nvme_swapper.available_swap_in_buffers():
break
if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE:
swap_in_params.append(param)
if len(swap_in_params) > 0:
swap_in_params[0].nvme_swapper.swap_in(
swap_in_params, async_op=True)
# Pre fetches the parameters for sub_modules that comes after
# the current sub_module. This call is asynchronous
def prefetch_next_sub_modules(self, sub_module, numel=5000000, nvme=False):
params_to_prefetch = []
if not self.prefetch_coordinator.trace_completed:
return params_to_prefetch
# prefetch if there is no current prefetching in flight
if not self.in_flight_handles and self.total_available_parameter_numel < self.max_available_parameters_in_numel:
params_to_prefetch = self.prefetch_coordinator.get_params_to_prefetch(
sub_module,
numel=numel)
self._all_gather(params_to_prefetch, async_op=True)
for param in params_to_prefetch:
param.ds_status = ZeroParamStatus.INFLIGHT
# keeping track of number of elements consumed by available parmaeters
self._increment_available_parameter_numel(param.ds_numel)
if nvme:
self._prefetch_nvme_param_partitions(
sub_module, params_to_prefetch)
self._print_prefetch_elements_info(sub_module, params_to_prefetch)
print_rank_0(
f"{'--' * self.hierarchy}--PreFetching parameters {[param.ds_id for param in params_to_prefetch]} and available {self.total_available_parameter_numel}, max limit {self.max_available_parameters_in_numel}",
force=False)
def _print_prefetch_elements_info(self, sub_module, params_to_prefetch):
sub_module_numel = 0.0
for name, param in sub_module.named_parameters(recurse=False):
sub_module_numel += param.ds_numel
numel_being_prefetched = 0
for param in params_to_prefetch:
numel_being_prefetched = param.ds_numel
print_rank_0(
f"{'--' * self.hierarchy}--PreFetching {numel_being_prefetched} numels and number of numel in the next sub module is {sub_module_numel}",
force=False)
def increment_step(self, sub_module):
self.prefetch_coordinator.increment_step(sub_module)
def reset_step(self):
self.prefetch_coordinator.reset_step()
'''----------------------------------------------------------------------'''
# Fetches the parameters in the sub_module
# This call is blocking
def fetch_sub_module(self, sub_module):
partitioned_params = []
params_in_flight = False
print_rank_0(
f"{'--' * self.hierarchy}Fetching params in module {debug_module2name_class(sub_module)}"
)
params_to_fetch = [
param for _,
param in sub_module.named_parameters(recurse=False)
]
# print([n for n,p in sub_module.named_parameters(recurse=False)])
if hasattr(sub_module, 'ds_external_parameters'):
print_rank_0(
f"{'--' * self.hierarchy}--Fetching external parameters {sub_module.ds_external_parameters()}"
)
params_to_fetch += [
param for _,
param in sub_module.ds_external_parameters()
]
# for _, param in sub_module.named_parameters(recurse=False):
for param in params_to_fetch:
param.ds_active_sub_modules += 1
print_rank_0(
f"{'--' * self.hierarchy}--Fetching parameters {debug_param2name_id_shape(param)} with active sub modules {param.ds_active_sub_modules}"
)
if param.ds_status == ZeroParamStatus.AVAILABLE:
print_rank_0(
f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is already available"
)
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
print_rank_0(
f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is being fetched"
)
partitioned_params.append(param)
# keeping track of number of elements consumed by available parmaeters
self._increment_available_parameter_numel(param.ds_numel)
print_rank_0(f"Incrementing with parameter id {param.ds_id}")
if param.ds_status == ZeroParamStatus.INFLIGHT:
params_in_flight = True
print_rank_0(
f"{'--' * self.hierarchy}--Parameters {debug_param2name_id(param)} is already in flight (prefetched)"
)
self.hierarchy += 1
# parameters are partitioned and need to be allgathered
self._all_gather(partitioned_params, async_op=True)
# parameters are inflight and communication needs to be completed
if partitioned_params or params_in_flight:
self._synchronize_communication()
for _, param in sub_module.named_parameters(recurse=False):
param.ds_status = ZeroParamStatus.AVAILABLE
print_rank_0(
f"Param {debug_param2name_id_shape_device(param)} norm={param.norm()}",
force=False)
# print_rank_0(f"After fetching (id, shape, device): {[(param.ds_id, param.shape, param.device) for param in sub_module.named_parameters(recurse=False)]}")
def release_sub_module(self, sub_module):
self.hierarchy -= 1
print_rank_0(
f"{'--' * self.hierarchy}Releasing params in module {debug_module2name_class(sub_module)}"
)
params_to_release = [
param for _,
param in sub_module.named_parameters(recurse=False)
]
if hasattr(sub_module, 'ds_external_parameters'):
# print_rank_0(f"Releasing external parameters {sub_module.ds_external_parameters()}")
params_to_release += [
param for _,
param in sub_module.ds_external_parameters()
]
# for _, param in sub_module.named_parameters(recurse=False):
for param in params_to_release:
param.ds_active_sub_modules -= 1
if not param.ds_active_sub_modules and not self._keep_for_later(
sub_module) and not param.ds_persist:
print_rank_0(
f"{'--' * self.hierarchy}--Releasing parameter {debug_param2name_id_numel(param)} active sub modules {param.ds_active_sub_modules} and keep for later {self._keep_for_later(sub_module)}",
force=False)
# Keeping track of number of elements that are consumed by available parameters
self._decrement_available_parameter_numel(param.ds_numel)
# report_memory_usage(
# f"Before releasing param {debug_param2name_id_numel(param)}",
# )
param.partition(hierarchy=self.hierarchy)
# report_memory_usage(
# f"After releasing param {debug_param2name_id_numel(param)}",
# )
param.ds_status = ZeroParamStatus.NOT_AVAILABLE
else:
print_rank_0(
f"{'--' * self.hierarchy}--Did not release param {debug_param2name_id_numel(param)} with active sub modules {param.ds_active_sub_modules}, keep for later={self._keep_for_later(sub_module)} and persistence={param.ds_persist}",
force=False)
def release_and_reset_parameter(self, param):
param.ds_active_sub_modules = 0
if param.ds_status == ZeroParamStatus.AVAILABLE:
print_rank_0(
f"Releasing unpartitioned param {debug_param2name_id_numel(param)} active sub-modules {param.ds_active_sub_modules} and persisitence {param.ds_persist}"
)
self._decrement_available_parameter_numel(param.ds_numel)
param.partition()
def _keep_for_later(self, sub_module):
if not self.prefetch_coordinator.trace_completed:
return False
if self.max_reuse_distance_in_numel == 0:
return False
reuse_distance_in_numel = self.prefetch_coordinator.get_reuse_distance_in_numel(
sub_module)
# print_rank_0(f"Reuse distance and numel for sub_module id {sub_module.id} is {reuse_distance_in_numel}")
return reuse_distance_in_numel < self.max_reuse_distance_in_numel
def _all_gather(self, partitioned_params, async_op=False):
with torch.cuda.stream(self.comm_stream):
handles = partitioned_params[0].all_gather(
param_list=partitioned_params,
async_op=async_op,
hierarchy=self.hierarchy) if partitioned_params else None
if handles is not None:
self.in_flight_handles.extend(handles)
self.params_in_flight.extend(partitioned_params)
def _synchronize_communication(self, synchronize_streams=True):
assert len(self.params_in_flight) == len(self.in_flight_handles)
for handle, param in zip(self.in_flight_handles, self.params_in_flight):
if handle is not None:
with torch.cuda.stream(self.comm_stream):
handle.wait()
param.ds_status = ZeroParamStatus.AVAILABLE
self.comm_stream.synchronize()
torch.cuda.synchronize() if synchronize_streams else None
self.in_flight_handles = []
self.params_in_flight = []
class PreBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, outputs):
ctx.module = module
ctx.pre_backward_function = pre_backward_function
if not hasattr(module, "applied_pre_backward_ref_cnt"):
module.applied_pre_backward_ref_cnt = 0
module.applied_pre_backward_ref_cnt += 1
# print(f"After Forward: {ctx.module.__class__.__name__}")
outputs = outputs.detach()
return outputs
@staticmethod
def backward(ctx, *args):
# print(f"Before Backward: {ctx.module.__class__.__name__}")
ctx.pre_backward_function(ctx.module)
return (None, None) + args
class PostBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, output):
ctx.module = module
if output.requires_grad:
# TODO SOME TIMES post backward does not report_memory_usage()ered debug in detail
# Should only cause increase in memory not correctness issue
# if output.grad_fn.__class__.__name__ == 'ViewBackward':
# ctx.view=True
# print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
# assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
# if module.ds_grads_remaining == 0:
# print(f"Before Forward: {ctx.module.__class__.__name__}")
module.ds_grads_remaining += 1
ctx.pre_backward_function = pre_backward_function
output = output.detach()
return output
@staticmethod
def backward(ctx, *args):
ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
if ctx.module.ds_grads_remaining == 0:
ctx.pre_backward_function(ctx.module)
# print(f"After Backward: {ctx.module.__class__.__name__}")
return (None, None) + args
INITIAL_MICRO_STEP_ID = -1
class ZeroRedundancyOptimizer_Level_3(Optimizer):
"""
ZeroRedundancyOptimizer_Level_3 designed to reduce the memory footprint
required for training large deep learning models.
For more details please report_memory_usage() Optimization Towards Training A Trillion Parameter Models
https://arxiv.org/abs/1910.02054
"""
def __init__(self,
module,
init_optimizer,
dp_paralllel_mode=ParallelMode.DATA,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=False,
contiguous_gradients=True,
reduce_bucket_size=500000000,
prefetch_bucket_size=50000000,
max_reuse_distance=1000000000,
max_live_parameters=1000000000,
param_persistence_threshold=100000,
reduce_scatter=True,
overlap_comm=False,
offload_optimizer_config=None,
offload_param_config=None,
sub_group_size=1000000000000,
clip_grad=0.0,
allreduce_always_fp32=False,
postscale_gradients=True,
gradient_predivide_factor=1.0,
gradient_accumulation_steps=1,
aio_config=None,
dtype=torch.half):
# mpu = None
# mpu is removed from the parameter list
# tensor parallel will be automatically detected later
# LSG: default parameter for compatibility
elastic_checkpoint = False
timers = None
dp_process_group = gpc.get_group(dp_paralllel_mode)
self.verbose = verbose
# LSG: in deepspeed deepspeed/runtime/zero/partition_parameters.py,
# self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
# the local device is obtained by env var LOCAL_RANK, thus, need to change this
# env var on the spot as LOCAL_RANK may not be present
if not 'LOCAL_RANK' in os.environ:
device_id = gpc.get_global_rank() % torch.cuda.device_count()
os.environ['LOCAL_RANK'] = str(device_id)
# self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
if self.verbose:
report_memory_usage("Stage 3 initialize beginning")
if dist.get_rank() == 0:
print(f"Reduce bucket size {reduce_bucket_size}")
print(f"Allgather bucket size {prefetch_bucket_size}")
# The fused optimizer does all the work. We need this layer for two reason:
# 1. maintain same user API from apex.fp16_utils
# 2. keep common stuff here in case we need to add ne552w fused optimizer later
# differences from apex.fp16_utils:
# - assume all model params in fp16
# - assume all params requires grad
# - flat by groups, not keeping state. TODO: remove state explicitly?
# - master gard and unflat master weight never exist. TODO: a way to save out unflat master?
if not torch.cuda.is_available:
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
self.defaults = init_optimizer.defaults
# Load pre-built or JIT compile (un)flatten ops
util_ops = UtilsBuilder().load()
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten
self.dtype = dtype
if not all(is_zero_param(p) for p in module.parameters()):
ds_config = {
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"zero_optimization": {
"offload_param": offload_param_config,
"offload_optimizer": offload_optimizer_config,
},
"aio": aio_config
}
if offload_param_config is not None:
remote_device = offload_param_config['device']
else:
remote_device = None
if offload_optimizer_config is not None:
pin_memory = offload_optimizer_config.get(OFFLOAD_OPTIMIZER_PIN_MEMORY, False)
else:
pin_memory = False
group = None
if gpc.is_initialized(ParallelMode.DATA):
group = gpc.get_group(ParallelMode.DATA)
Init(module=module, data_parallel_group=group, dtype=self.dtype,
remote_device=remote_device, config_dict_or_path=ds_config,
pin_memory=pin_memory)
for m in module.modules():
_init_external_params(m)
self.module = module
self.elastic_checkpoint = elastic_checkpoint
self.overlap_comm = overlap_comm
# Replace ._parameters with a new class to enable auto-registration of
# external parameters
_inject_parameters(module, ZeROOrderedDict)
if self.overlap_comm:
self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda()
###################### offload optimizer setup ##################################
self.optimizer_swapper = None
self.swap_optimizer = False
self.offload_optimizer = False
self.offload_optimizer_pin_memory = False
self.offload_optimizer_fast_init = False
if offload_optimizer_config is not None:
self.offload_optimizer = True
self.offload_optimizer_pin_memory = offload_optimizer_config[
OFFLOAD_OPTIMIZER_PIN_MEMORY]
self.swap_optimizer = offload_optimizer_config[
OFFLOAD_OPTIMIZER_DEVICE] == OFFLOAD_NVME_DEVICE
self.offload_optimizer_fast_init = offload_optimizer_config[
OFFLOAD_OPTIMIZER_FAST_INIT]
###################### offload param setup ##################################
self.offload_param = False
self.offload_param_pin_memory = False
self.params_in_nvme_and_cpu = False
self.max_params_in_cpu = 0
if offload_param_config is not None:
assert self.offload_optimizer, "parameter offload is only available with optimizer state offload"
self.offload_param = True
self.offload_param_pin_memory = offload_param_config[
OFFLOAD_PARAM_PIN_MEMORY]
self.params_in_nvme_and_cpu = offload_param_config[
OFFLOAD_PARAM_DEVICE] == OFFLOAD_NVME_DEVICE
self.max_params_in_cpu = offload_param_config[OFFLOAD_PARAM_MAX_IN_CPU]
if self.verbose:
print_rank_0(
f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}",
force=False)
self.deepspeed_adam_offload = (self.offload_optimizer
and type(init_optimizer) == DeepSpeedCPUAdam)
self.device = torch.cuda.current_device(
) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE
############################################################################
if self.verbose:
report_memory_usage("Before Partitioned Parameter Coordinator")
fetch_stream = torch.cuda.Stream() if self.overlap_comm else None
self.param_coordinator = PartitionedParameterCoordinator(
comm_stream=fetch_stream,
max_reuse_distance_in_numel=int(max_reuse_distance),
max_available_parameters_in_numel=int(max_live_parameters))
if self.verbose:
report_memory_usage("After Partitioned Parameter Coordinator")
# self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream())
# -------------Stage 3 Setup-------------------#
# parameters smaller than the threshold will be collectively gathered at the
# end of the optimizer step and will be kept till the end of the backward pass
# TODO maybe worth just replicating these parameters and doing all reduce for them
self.persistence_threshold = int(param_persistence_threshold)
self.persistent_parameters = self.persistent_parameters()
self.setup_zero_stage3_hooks()
# resetting ds_tensor just in case parameters have been changed after initialization
# example .half() or .to()
# self.reset_ds_tensor()
# ---------------------------------------------#
self.timers = timers
self.reduce_scatter = reduce_scatter
self.dp_process_group = dp_process_group
self.partition_count = dist.get_world_size(group=self.dp_process_group)
if gpc.is_initialized(ParallelMode.TENSOR) is None:
self.model_parallel_group = None
self.model_parallel_rank = 0
else:
self.model_parallel_group = gpc.get_group(ParallelMode.TENSOR)
self.model_parallel_rank = gpc.get_local_rank(ParallelMode.TENSOR)
self.overflow = False
self.clip_grad = clip_grad
self.allreduce_always_fp32 = allreduce_always_fp32
self.gradient_predivide_factor = gradient_predivide_factor
self.postscale_gradients = postscale_gradients
self.gradient_accumulation_steps = gradient_accumulation_steps
self.micro_step_id = INITIAL_MICRO_STEP_ID
if self.reduce_scatter:
assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"
assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled"
assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled"
# Holds the mode parameter
# The param.data may not hold any meaningful data
# when param's status is NOT_AVAILABLE or IN_FLGHT
self.fp16_groups = []
# Hold partitioned parameters
self.fp16_partitioned_groups = []
# Holds a fused and flattened copy of the parameters
self.fp16_partitioned_groups_flat = []
self.fp16_partitioned_groups_flat_numel = []
# defragmented pinned memory
self.param_groups_fp16_flat_cpu_memory = []
# a single 32-bit partition of the parallel partitioned parameters
# that this process will update
self.fp32_partitioned_groups_flat = []
self.next_swappable_fp32_partitioned_groups = []
# number of elements per partition in each group
self.partition_size = []
self.all_reduce_print = False
self.prefetch_elements = int(prefetch_bucket_size)
# padding on each partition for alignment purposes
self.groups_padding = []
self.sub_group_size = sub_group_size
self.sub_group_to_group_id = {}
if self.verbose:
report_memory_usage("Before creating fp16 partitions")
self._create_fp16_partitions_with_defragmentation()
num_fp16_subgroups = len(self.fp16_partitioned_groups_flat)
if self.verbose:
report_memory_usage(
f"After creating fp16 partitions: {num_fp16_subgroups}")
# Optimizer ensor swapping
if self.swap_optimizer:
self._configure_tensor_swapping(
offload_optimizer_config, aio_config)
if self.verbose:
report_memory_usage("Before creating fp32 partitions")
self._create_fp32_partitions()
if self.verbose:
report_memory_usage("After creating fp32 partitions")
dist.barrier()
# To support pipelined optimizer swapping
self._create_next_swappable_fp32_groups()
if self.verbose:
report_memory_usage("Before initializing optimizer states")
self.initialize_optimizer_states()
if self.verbose:
report_memory_usage("After initializing optimizer states")
dist.barrier()
if dist.get_rank() == 0 and self.verbose:
print(f"optimizer state initialized")
self.reduce_bucket_size = int(reduce_bucket_size)
self.reduction_event = torch.cuda.Event(
enable_timing=False, blocking=False)
self.reduction_stream = torch.cuda.Stream(
) if self.overlap_comm else torch.cuda.current_stream()
self.callback_queued = False
self.copy_grad_stream = torch.cuda.Stream()
self.param_dict = {}
# map between param_id and bool to specify if a param is in this partition
self.is_param_in_current_partition = {}
self.contiguous_gradients = contiguous_gradients
self.extra_large_param_to_reduce = None
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.elements_in_ipg_bucket = 0
self.params_already_reduced = []
self.is_gradient_accumulation_boundary = True
self._release_ipg_buffers()
self.previous_reduced_grads = None
# simplified param id
self.param_id = {}
count = 0
for i, params_group in enumerate(self.fp16_groups):
for param in params_group:
unique_id = id(param)
self.param_id[unique_id] = count
self.param_dict[count] = param
self.params_already_reduced.append(False)
count = count + 1
# Largest partitioned param
largest_partitioned_param_numel = max([
max([tensor.numel() for tensor in fp16_partitioned_group])
for fp16_partitioned_group in self.fp16_partitioned_groups
])
if self.verbose:
print_rank_0(
f'Largest partitioned param numel = {largest_partitioned_param_numel}',
force=False)
if self.verbose:
report_memory_usage(f"Before Set Grad positions")
self.grad_position = {}
self.set_grad_positions()
if self.verbose:
report_memory_usage(f"Before CPU Offload initialization")
self.grads_in_partition = None
if self.offload_optimizer:
self.accumulated_grads_in_cpu = {}
self.norm_for_param_grads = {}
self.local_overflow = False
self.temp_grad_buffer_for_gpu_offload = torch.zeros(
largest_partitioned_param_numel,
device=torch.cuda.current_device(),
dtype=self.dtype)
self.temp_grad_gpu_buffer = torch.zeros(largest_partitioned_param_numel,
device=torch.cuda.current_device(),
dtype=self.dtype)
if self.verbose:
report_memory_usage(f"After CPU Offload initialization")
# stores if a partition has been reduced in this step
self.is_partition_reduced = {}
# stores if a grad in a partition has been computed or not
self.is_grad_computed = {}
# will store the averaged gradients required by this parititon
self.averaged_gradients = {}
# creates backward hooks for gradient partitioning
self.create_reduce_and_remove_grad_hooks()
# exit(0)
# we may have a way of fusing dynamic scale. Do not support for now
if self.dtype == torch.float or not dynamic_loss_scale:
loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(scale=loss_scale_value)
cur_iter = 0
else:
if dynamic_loss_args is None:
self.loss_scaler = DynamicLossScaler()
else:
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
self.dynamic_loss_scale = True
self.debug_fp16_grads = [{} for _ in self.fp16_groups]
if dist.get_rank(group=self.dp_process_group) == 0 and self.verbose:
report_memory_usage(f"After initializing ZeRO optimizer")
def _configure_tensor_swapping(self, offload_optimizer_config, aio_config):
nvme_swap_folder = os.path.join(
offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH],
'zero_stage_3')
os.makedirs(nvme_swap_folder, exist_ok=True)
if torch.distributed.get_rank() == 0 and self.verbose:
print(f'Tensor Swapping: Adding optimizer tensors')
swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config[
OFFLOAD_OPTIMIZER_PIPELINE] else PartitionedOptimizerSwapper
self.optimizer_swapper = swapper_type(
swap_config=offload_optimizer_config,
aio_config=aio_config,
base_folder=nvme_swap_folder,
optimizer=self.optimizer,
largest_numel=max(self.fp16_partitioned_groups_flat_numel),
device=self.device,
dtype=torch.float32,
timers=self.timers)
def _create_fp16_partitions(self):
dist.barrier()
partition_id = dist.get_rank(group=self.dp_process_group)
# loop to deal with groups
for j, param_group in enumerate(self.optimizer.param_groups):
sub_groups = self._create_fp16_sub_groups(param_group['params'])
for sub_group in sub_groups:
i = len(self.fp16_groups)
# push this group to list before modify
self.fp16_groups.append(sub_group)
self.sub_group_to_group_id[i] = j
# These are the list of the partitioned parameters
self.fp16_partitioned_groups.append(
[param.ds_tensor for param in self.fp16_groups[i]])
if self.verbose:
print_rank_0(
f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}"
)
# Record padding required to align group to world size (only applies to last rank)
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
padding = [p.padding_size() for p in self.fp16_groups[i]]
else:
padding = [0] * len(self.fp16_groups[i])
self.groups_padding.append(padding)
# not sure why apex was cloning the weights before flattening
# removing cloning here
if self.verbose:
report_memory_usage(f"Before Flattening param group {i}")
if not self.offload_param:
if self.verbose:
report_memory_usage(
f"Before moving param group {i} to CPU")
# move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.fp16_partitioned_groups[i])
if self.verbose:
report_memory_usage(
f"After moving param group {i} to CPU")
# create flat buffer in CPU and move to GPU
self.fp16_partitioned_groups_flat.append(
self.flatten_dense_tensors_aligned(
self.fp16_partitioned_groups[i],
dist.get_world_size(group=self.dp_process_group)).cuda(
torch.cuda.current_device()))
if self.verbose:
report_memory_usage(
f"After flattening and moving param group {i} to GPU"
)
else:
# Without the detach, report_memory_usage()lattening becomes part of the
# model graph causing errors downstream
self.fp16_partitioned_groups_flat.append(
self.flatten_dense_tensors_aligned(
self.fp16_partitioned_groups[i],
dist.get_world_size(
group=self.dp_process_group)).detach().pin_memory())
if self.verbose:
report_memory_usage(f"After Flattening param group {i}")
# set model fp16 weight to slices of flattened buffer
updated_params = self.unflatten(self.fp16_partitioned_groups_flat[i],
self.fp16_partitioned_groups[i])
for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params):
partitioned_param.data = q.data
def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False):
'''If flat buffer is None then the parameters in the param_list are
not copied to the flat buffer. This is because they excede the number of max_params_in_cpu
Some of these parameters may aready be in CPU in unflattened buffers
or they maybe in GPU, or they maybe in NVME. If they are in NVME, then
they will be marked as NOT_AVAILABLE, and will be moved to CPU when they are
needed during training.'''
if flat_buffer is None:
# this dst buffer is on NVMe, so skip this
return
start = 0
for param in param_list:
src = param.ds_tensor
dest = flat_buffer.narrow(0, start, src.ds_numel)
start = start + src.ds_numel
'''if the parameter was initialized in nvme then bring it to the destination buffer directly'''
if src.status == PartitionedParamStatus.NOT_AVAILABLE:
if self.verbose:
print_rank_0(
f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU"
)
param.nvme_swapper.swap_into_buffer(param, dest)
src.data = dest.data
src.status = PartitionedParamStatus.AVAILABLE
else:
assert src.status == PartitionedParamStatus.AVAILABLE, "Partitioned Parm must be avialable here"
if not avoid_copy:
dest.data.copy_(src.data)
src.data = dest.data
# Final location must be gpu/cpu in this case
param.ds_tensor.final_location = 'not-nvme'
def _create_param_groups_fp16_flat_cpu_memory(self):
aggregate_params_count = 0
for j, param_group in enumerate(self.optimizer.param_groups):
params_in_group = sum(
[p.ds_tensor.ds_numel for p in param_group['params']])
flat_buffer_size = params_in_group
if self.params_in_nvme_and_cpu and \
aggregate_params_count + params_in_group > self.max_params_in_cpu:
flat_buffer_size = max(0,
self.max_params_in_cpu - aggregate_params_count)
aggregate_params_count += params_in_group
if flat_buffer_size > 0:
if self.verbose:
print_rank_0(f"group {j} flat buffer size {flat_buffer_size}",
force=False)
self.param_groups_fp16_flat_cpu_memory.append(
torch.empty(int(flat_buffer_size),
dtype=self.dtype,
pin_memory=True))
else:
if self.verbose:
print_rank_0(
f"No flat buffer size. Param group size was {params_in_group}",
force=False)
self.param_groups_fp16_flat_cpu_memory.append(
torch.empty(1,
dtype=self.dtype))
def _create_fp16_partitions_with_defragmentation(self):
dist.barrier()
partition_id = dist.get_rank(group=self.dp_process_group)
create_fp16_flat_reuse_buffer = False
largest_partition_numel = []
max_partition_numel = 0
# create a flat CPU memory allocation for each param group
if self.offload_param:
self._create_param_groups_fp16_flat_cpu_memory()
# loop to deal with groups
for j, param_group in enumerate(self.optimizer.param_groups):
sub_groups = self._create_fp16_sub_groups(param_group['params'])
if self.verbose:
print_rank_0(
f'fp16 group {j} has {len(sub_groups)} subgroups', force=False)
flat_offset = 0
for sub_group in sub_groups:
i = len(self.fp16_groups)
# push this group to list before modify
self.fp16_groups.append(sub_group)
self.sub_group_to_group_id[i] = j
# comment out for zero_to_fp32 debug
# if torch.distributed.get_rank() == 0:
# for param in self.fp16_groups[i]:
# print(f"{debug_param2name_id_shape(param)} {param.ds_shape}")
# These are the list of the partitioned parameters
self.fp16_partitioned_groups.append(
[param.ds_tensor for param in self.fp16_groups[i]])
total_elements = sum(
[t.ds_numel for t in self.fp16_partitioned_groups[i]])
self.fp16_partitioned_groups_flat_numel.append(total_elements)
if total_elements > max_partition_numel:
largest_partition_numel = [
t.ds_numel for t in self.fp16_partitioned_groups[i]
]
max_partition_numel = total_elements
if self.verbose:
print_rank_0(
f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}"
)
# Record padding required to align group to world size (only applies to last rank)
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
padding = [p.padding_size() for p in self.fp16_groups[i]]
else:
padding = [0] * len(self.fp16_groups[i])
self.groups_padding.append(padding)
# not sure why apex was cloning the weights before flattening
# removing cloning here
if self.verbose:
report_memory_usage(
f"Before Flattening param subgroup {i}")
# all partitioned parameters remain in GPU during training
if not self.offload_param:
if self.verbose:
report_memory_usage(
f"Before moving param subgroup group {i} to CPU")
# move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.fp16_partitioned_groups[i])
if self.verbose:
report_memory_usage(
f"After moving param subgroup {i} to CPU")
# create flat buffer in CPU and move to GPU
self.fp16_partitioned_groups_flat.append(
self.flatten_dense_tensors_aligned(
self.fp16_partitioned_groups[i],
1).cuda(torch.cuda.current_device()))
if self.verbose:
report_memory_usage(
f"After flattening and moving param subgroup {i} to GPU")
# all partitioned parameters are in CPU during training
else:
if self.verbose:
print_rank_0(
f"Params in nvme and cpu {self.params_in_nvme_and_cpu}")
# Flat buffer may not be available for parameters that reside in NVME
if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= \
self.param_groups_fp16_flat_cpu_memory[
j].numel():
fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[
j].narrow(0,
flat_offset,
total_elements)
if self.verbose:
print_rank_0(
f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elemets {flat_offset + total_elements}",
force=False)
# these parameters reside in NVME and
elif self.params_in_nvme_and_cpu:
fp16_partitioned_group_flat = None
if self.verbose:
print_rank_0(
f"No flat buffer for sub group {i} of {total_elements} elements",
force=False)
else:
assert False, "Either params are in nvme, or they are in CPU memory. This code path should not be triggered. Please report_memory_usage()ms_in_cpu and params_in_nvme configs"
self.fp16_partitioned_groups_flat.append(
fp16_partitioned_group_flat)
flat_offset += total_elements
# move param to flat buffer for both param offload on/off
self._move_to_flat_buffer(self.fp16_groups[i],
self.fp16_partitioned_groups_flat[i],
avoid_copy=not self.offload_param)
if self.verbose:
report_memory_usage(f"After Flattening param group {i}")
# create a pinned memory to be used for swapping out params to NVME after optimizer step
if self.fp16_partitioned_groups_flat[-1] is None:
create_fp16_flat_reuse_buffer = True
if self.verbose:
report_memory_usage(f"After Flattening param subgroup {i}")
if create_fp16_flat_reuse_buffer:
assert len(
largest_partition_numel) > 0, f'Unexpected that largest partition is empty'
self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space(
largest_partition_numel)
def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id):
offset = 0
elements_in_sub_group = sum(
[t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]])
assert (flat_buffer.numel() == elements_in_sub_group)
for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]):
dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel)
if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE:
if self.verbose:
print_rank_0(
f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}"
)
param.nvme_swapper.swap_in([param], async_op=False)
dest.data.copy_(partitioned_param.data)
param.nvme_swapper.remove_partition_and_release_buffers([
param])
if self.verbose:
print_rank_0(f"Swapping in {param.ds_id} done")
else:
dest.data.copy_(partitioned_param.data)
offset += partitioned_param.ds_numel
def _create_next_swappable_fp32_groups(self):
reverse_order_indices = [
i for i in range(len(self.fp32_partitioned_groups_flat))
]
reverse_order_indices.reverse()
next_group = None
for i in reverse_order_indices:
self.next_swappable_fp32_partitioned_groups.append(next_group)
if self._swappable_optimizer_subgroup(i):
next_group = self.fp32_partitioned_groups_flat[i]
self.next_swappable_fp32_partitioned_groups.reverse()
def _get_sub_group_partitions(self, sub_group_id):
sub_group_partitions = []
for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]):
if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE:
swap_path = param.nvme_swapper.get_path(param, True)
sub_group_partitions.append((partitioned_param,
param.ds_tensor.ds_numel,
swap_path))
else:
sub_group_partitions.append((partitioned_param,
partitioned_param.ds_numel,
None))
return sub_group_partitions
def _create_fp32_partitions(self):
cpu_memory_usage = 0
cpu_memory_sub_groups = 0
nvme_memory_usage = 0
num_swappable_partitions = 0
num_swap_from_nvme_partitions = 0
num_swap_from_cpu_partitions = 0
swap_from_nvme_memory_usage = 0
swap_from_cpu_memory_usage = 0
GIGA_BYTES = (1024 ** 3)
swappable_fp32_tensors = []
swappable_fp16_src_tensors = []
nvme_fp16_partitions_info = []
nvme_fp16_num_elems = []
nvme_fp32_dest_tensors = []
fp32_element_size = torch.tensor(
[], dtype=torch.float32).element_size()
for i, tensor in enumerate(self.fp16_partitioned_groups_flat):
num_elements = self.fp16_partitioned_groups_flat_numel[i]
# a partition of the fp32 master weights that will be updated by this process
if self._swappable_optimizer_subgroup(i):
self.fp32_partitioned_groups_flat.append(torch.Tensor())
nvme_memory_usage += (fp32_element_size * num_elements)
num_swappable_partitions += 1
if self.params_in_nvme_and_cpu and tensor is None:
num_swap_from_nvme_partitions += 1
swap_from_nvme_memory_usage += (
fp32_element_size * num_elements)
if self.offload_optimizer_fast_init:
sub_group_partitions = self._get_sub_group_partitions(
i)
nvme_fp16_partitions_info.append(sub_group_partitions)
nvme_fp16_num_elems.append(num_elements)
nvme_fp32_dest_tensors.append(
self.fp32_partitioned_groups_flat[i])
else:
unpinned_fp32_buffer = torch.empty(num_elements,
device=self.device,
dtype=torch.float)
self._swap_in_sub_group_to_flat_buffer(
unpinned_fp32_buffer, i)
self.optimizer_swapper.initialize_parameters(
parameters=[self.fp32_partitioned_groups_flat[i]],
src_tensors=[unpinned_fp32_buffer])
else:
num_swap_from_cpu_partitions += 1
swap_from_cpu_memory_usage += (
fp32_element_size * num_elements)
swappable_fp32_tensors.append(
self.fp32_partitioned_groups_flat[i])
swappable_fp16_src_tensors.append(
self.fp16_partitioned_groups_flat[i])
else:
cpu_memory_usage += (fp32_element_size * num_elements)
cpu_memory_sub_groups += 1
if self.params_in_nvme_and_cpu and tensor is None:
unpinned_fp32_buffer = torch.empty(num_elements,
device=self.device,
dtype=torch.float)
self._swap_in_sub_group_to_flat_buffer(
unpinned_fp32_buffer, i)
self.fp32_partitioned_groups_flat.append(
unpinned_fp32_buffer)
else:
self.fp32_partitioned_groups_flat.append(
self.fp16_partitioned_groups_flat[i].to(
self.device).clone().float().detach())
self.fp32_partitioned_groups_flat[
i].requires_grad = True # keep this in case internal optimizer uses it
if len(swappable_fp32_tensors) > 0:
self.optimizer_swapper.initialize_parameters(
parameters=swappable_fp32_tensors,
src_tensors=swappable_fp16_src_tensors)
if len(nvme_fp32_dest_tensors) > 0:
fp16_pinned_buffers = self.fp16_groups[0][
0].nvme_swapper.reserve_available_buffers()
assert len(fp16_pinned_buffers) > 0
self.optimizer_swapper.initialize_from_swapped_fp16_params(
fp16_partitions_info=nvme_fp16_partitions_info,
fp16_num_elems=nvme_fp16_num_elems,
fp16_pinned_buffers=fp16_pinned_buffers,
fp32_parameters=nvme_fp32_dest_tensors)
self.fp16_groups[0][0].nvme_swapper.release_reserved_buffers()
nvme_gigabytes = nvme_memory_usage / GIGA_BYTES
if self.verbose:
print_rank_0(
f'Swappable FP32 Partitions: count={num_swappable_partitions} size={nvme_gigabytes:5.2f} GB',
force=False)
if self.params_in_nvme_and_cpu:
if self.verbose:
print_rank_0(
f'Swap from NVMe Partitions: count = {num_swap_from_nvme_partitions}, size = {swap_from_nvme_memory_usage / GIGA_BYTES:5.2f}GB',
force=False)
print_rank_0(
f'Swap from CPU Partitions: count = {num_swap_from_cpu_partitions}, size = {swap_from_cpu_memory_usage / GIGA_BYTES:5.2f}GB',
force=False)
cpu_memory_gigabytes = cpu_memory_usage / GIGA_BYTES
if self.verbose:
print_rank_0(
f'In-Memory FP32 Partitions: count={cpu_memory_sub_groups} size={cpu_memory_gigabytes:5.2f} GB',
force=False)
# Clear for on-the-fly population before the optimizer step
for param_group in self.optimizer.param_groups:
param_group['params'] = []
def _create_fp16_sub_groups(self, params_group):
params_group_numel = sum([param.partitioned_size()
for param in params_group])
sub_group_size = self.sub_group_size
if sub_group_size is None or sub_group_size >= params_group_numel:
return [params_group]
sub_groups = []
sub_group = []
local_sub_group_size = 0
for param in params_group:
sub_group.append(param)
local_sub_group_size += param.partitioned_size()
if local_sub_group_size >= sub_group_size or id(param) == id(
params_group[-1]):
sub_groups.append(sub_group)
sub_group = []
local_sub_group_size = 0
return sub_groups
# def reset_ds_tensor(self):
# for name, param in self.module.named_parameters(recurse=True):
# assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible"
# assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now"
# param.ds_tensor.data = param.data
def setup_zero_stage3_hooks(self):
self.hierarchy = 0
self._register_hooks_recursively(self.module)
# reset step at the beginning of forward
def _pre_forward_hook(module, *args):
self.param_coordinator.reset_step()
# reset step if in inference mode
def _end_of_forward_hook(module, *args):
if not torch._C.is_grad_enabled():
self.param_coordinator.reset_step()
# likely one of them should be enough but just to be safe
self.module.register_forward_hook(_end_of_forward_hook)
self.module.register_forward_pre_hook(_pre_forward_hook)
# Add top todule to stack trace
global FWD_MODULE_STACK
FWD_MODULE_STACK.append(self.module)
def persistent_parameters(self):
persistent_params = []
total_persistent_parameters = 0
params_count = 0
for _, param in self.module.named_parameters(recurse=True):
if param.ds_numel < self.persistence_threshold:
params_count += 1
param.ds_persist = True
persistent_params.append(param)
total_persistent_parameters += param.ds_numel
if self.verbose:
print_rank_0(
f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params",
force=False)
return persistent_params
def _register_hooks_recursively(self, module, count=[0]):
my_count = count[0]
module.id = my_count
# print(f"{module.__class__} : {module.id}")
for child in module.children():
count[0] = count[0] + 1
self._register_hooks_recursively(child, count=count)
def _pre_forward_module_hook(module, *args):
self.pre_sub_module_forward_function(module)
def _post_forward_module_hook(module, input, output):
global FWD_MODULE_STACK
FWD_MODULE_STACK.pop()
if output is None:
output = []
elif not isinstance(output, (list, tuple)):
if torch.is_tensor(output):
output = [output]
else:
# print(f'got UNKNOWN type {type(output)}')
outputs = []
output = output if isinstance(
output, dict) else vars(output)
for name, val in output.items():
if not name.startswith('__') and torch.is_tensor(val):
outputs.append(val)
output = outputs
# print(f'convert output to {output}')
for item in filter(lambda item: is_zero_param(item), output):
if not any(id(item) in m._external_params for m in FWD_MODULE_STACK):
item.ds_active_sub_modules += 1
module_to_register = FWD_MODULE_STACK[-1]
if self.verbose:
print_rank_0(
f'Registering dangling parameter for module {module_to_register.__class__.__name__}.',
force=False)
register_external_parameter(module_to_register, item)
# It's possible that the parameter was already external to the completed module. If so, remove it the
# registration as it will be covered by the outer module instead.
if id(item) in module._external_params:
if self.verbose:
print_rank_0(
f' Unregistering nested dangling parameter from module {module.__class__.__name__}',
force=False)
unregister_external_parameter(module, item)
item.all_gather()
self.post_sub_module_forward_function(module)
def _pre_backward_module_hook(module, inputs, output):
def _run_before_backward_function(sub_module):
# some models (e.g. Albert) may run multiple forwards on the same layer in a loop
# before doing backwards, so each backward will need a pre-fetch - using reference
# counting to support this scenario
# print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
if sub_module.applied_pre_backward_ref_cnt > 0:
self.pre_sub_module_backward_function(sub_module)
sub_module.applied_pre_backward_ref_cnt -= 1
# print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")
return _apply_to_tensors_only(module,
PreBackwardFunction,
_run_before_backward_function,
output)
# This is an alternate to doing _post_backward_module_hook
# it uses tensor.register_hook instead of using torch.autograd.Function
def _alternate_post_backward_module_hook(module, inputs):
module.ds_grads_remaining = 0
# print(f"Before Forward {module.__class__.__name__}")
def _run_after_backward_hook(*unused):
module.ds_grads_remaining = module.ds_grads_remaining - 1
if module.ds_grads_remaining == 0:
# print(f"After backward {module.__class__.__name__}")
self.post_sub_module_backward_function(module)
def _run_before_forward_function(input):
if input.requires_grad:
module.ds_grads_remaining += 1
return _apply_forward_and_backward_to_tensors_only(
module,
_run_before_forward_function,
_run_after_backward_hook,
inputs)
def _post_backward_module_hook(module, inputs):
module.ds_grads_remaining = 0
def _run_after_backward_function(sub_module):
if sub_module.ds_grads_remaining == 0:
self.post_sub_module_backward_function(sub_module)
return _apply_to_tensors_only(module,
PostBackwardFunction,
_run_after_backward_function,
inputs)
# Pre forward hook
module.register_forward_pre_hook(_pre_forward_module_hook)
# Post forward hook
module.register_forward_hook(_post_forward_module_hook)
# Pre backward hook
module.register_forward_hook(_pre_backward_module_hook)
# post backward hook
module.register_forward_pre_hook(_post_backward_module_hook)
def pre_sub_module_forward_function(self, sub_module):
if self.verbose:
report_memory_usage(
f"Before sub module function {sub_module.__class__.__name__}")
global FWD_MODULE_STACK
FWD_MODULE_STACK.append(sub_module)
self.param_coordinator.record_trace(sub_module)
self.param_coordinator.fetch_sub_module(sub_module)
if self.verbose:
report_memory_usage(
f"Before sub module function {sub_module.__class__.__name__} after fetch")
self.param_coordinator.prefetch_next_sub_modules(
sub_module,
numel=self.prefetch_elements,
nvme=self.params_in_nvme_and_cpu)
if self.verbose:
report_memory_usage(
f"Before sub module function {sub_module.__class__.__name__} after prefetch")
self.param_coordinator.increment_step(sub_module)
def post_sub_module_forward_function(self, sub_module):
if self.verbose:
report_memory_usage(
f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release")
self.param_coordinator.release_sub_module(sub_module)
if self.verbose:
report_memory_usage(
f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release")
def pre_sub_module_backward_function(self, sub_module):
self.param_coordinator.record_trace(sub_module)
self.param_coordinator.fetch_sub_module(sub_module)
self.param_coordinator.prefetch_next_sub_modules(sub_module,
numel=self.prefetch_elements)
self.param_coordinator.increment_step(sub_module)
def post_sub_module_backward_function(self, sub_module):
if self.verbose:
report_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release")
self.param_coordinator.release_sub_module(sub_module)
if self.verbose:
report_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release")
def _release_ipg_buffers(self):
if self.contiguous_gradients:
self.ipg_buffer = None
if not self.offload_optimizer and self.is_gradient_accumulation_boundary:
self.grads_in_partition = None
self.grads_in_partition_offset = 0
def _optimizer_step(self, sub_group_id):
param_group_id = self.sub_group_to_group_id[sub_group_id]
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
fp16_param = self.fp16_partitioned_groups_flat[sub_group_id]
self.optimizer.param_groups[param_group_id]['params'] = [fp32_param]
self.optimizer.step()
self.optimizer.param_groups[param_group_id]['params'] = []
def _swappable_optimizer_subgroup(self, sub_group_id):
if not self.swap_optimizer:
return False
return self.optimizer_swapper.swappable_tensor(
None,
numel=self.fp16_partitioned_groups_flat_numel[sub_group_id])
def _partitioned_params_swap_out(self, i):
offset = 0
fp32_param = self.fp32_partitioned_groups_flat[i]
assert fp32_param is not None, \
f'fp32 parameters of sub_group {i} is None'
swap_fp16_params = []
swap_fp32_params = []
for param, partitioned_param in zip(self.fp16_groups[i], self.fp16_partitioned_groups[i]):
src = fp32_param.narrow(0, offset, partitioned_param.ds_numel)
if partitioned_param.status == PartitionedParamStatus.AVAILABLE:
partitioned_param.data.copy_(src.data)
else:
swap_fp32_params.append(src)
swap_fp16_params.append(param)
offset += partitioned_param.ds_numel
if len(swap_fp16_params):
swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params(
dst_fp16_params=swap_fp16_params,
src_fp32_params=swap_fp32_params)
def initialize_optimizer_states(self):
num_subgroups = len(self.fp16_groups)
largest_numel = max(
[sum([p.ds_numel for p in psg]) for psg in self.fp16_partitioned_groups])
gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype
gradient_buffer = torch.zeros(int(largest_numel),
dtype=gradient_dtype,
device=self.device)
timers = self.timers
timer_names = set()
if self.swap_optimizer:
self.optimizer_swapper.init_timers()
INIT_OPTIMIZER_TIMER = 'init_optimizer_state'
timer_names.add(INIT_OPTIMIZER_TIMER)
self.start_timers([INIT_OPTIMIZER_TIMER])
for i, group in enumerate(self.fp16_groups):
swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(
i)
swappable_param_subgroup = self.fp16_partitioned_groups_flat[i] is None
num_elements = int(self.fp16_partitioned_groups_flat_numel[i])
if self.verbose:
report_memory_usage(
f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}')
if swappable_optimizer_subgroup:
self._optimizer_states_and_gradient_swap_in(i, timer_names)
if self.offload_optimizer and not swappable_optimizer_subgroup:
subgroup_gradient_buffer = torch.zeros(num_elements,
dtype=gradient_dtype,
device=self.device)
if self.offload_optimizer_pin_memory:
subgroup_gradient_buffer = subgroup_gradient_buffer.pin_memory()
self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer
else:
self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(
0,
0,
num_elements)
self._optimizer_step(i)
if swappable_param_subgroup:
self._partitioned_params_swap_out(i)
if swappable_optimizer_subgroup:
self._optimizer_states_and_gradient_swap_out(i, timer_names)
if self.verbose:
report_memory_usage(
f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}')
self.stop_timers([INIT_OPTIMIZER_TIMER])
self.log_timers(timer_names)
if self.swap_optimizer:
self.optimizer_swapper.log_timers()
if not self.offload_optimizer:
for group in self.fp32_partitioned_groups_flat:
group.grad = None
# Reset steps
return
#########################################################################
#########################ZeRO Partition Gradients########################
#########################################################################
def get_first_param_index(self, group_id, param_group, partition_id):
for index, param in enumerate(param_group):
param_id = self.get_param_id(param)
if partition_id in self.param_to_partition_ids[group_id][param_id]:
return index
return None
def initialize_gradient_partitioning_data_structures(self):
total_partitions = dist.get_world_size(group=self.dp_process_group)
for i, param_group in enumerate(self.fp16_groups):
self.param_to_partition_ids[i] = {}
self.is_partition_reduced[i] = {}
self.total_grads_in_partition[i] = {}
self.remaining_grads_in_partition[i] = {}
self.is_grad_computed[i] = {}
self.grad_partition_insertion_offset[i] = {}
self.grad_start_offset[i] = {}
self.first_param_index_in_partition[i] = {}
for partition_id in range(total_partitions):
self.is_grad_computed[i][partition_id] = {}
self.grad_partition_insertion_offset[i][partition_id] = {}
self.grad_start_offset[i][partition_id] = {}
self.initialize_gradient_partition(
i, param_group, partition_id)
self.is_partition_reduced[i][partition_id] = False
self.first_param_index_in_partition[i][
partition_id] = self.get_first_param_index(
i,
param_group,
partition_id)
def independent_gradient_partition_epilogue(self):
if self.verbose:
self.report_ipg_memory_usage(
f"In ipg_epilogue before reduce_ipg_grads", 0)
self.reduce_ipg_grads()
if self.verbose:
self.report_ipg_memory_usage(
f"In ipg_epilogue after reduce_ipg_grads", 0)
if self.overlap_comm:
self.reduction_stream.synchronize()
with torch.cuda.stream(self.reduction_stream):
self.partition_previous_reduced_grads()
# if dist.get_rank() == 0:
# print()("Params already reduced %s", self.params_already_reduced)
for i in range(len(self.params_already_reduced)):
self.params_already_reduced[i] = False
# in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad
# TODO: use a similar code path for both cpu_offload and non-cpu offload
if not self.offload_optimizer:
for i, sub_group in enumerate(self.fp16_groups):
self.averaged_gradients[i] = [
torch.zeros_like(param.ds_tensor) if param.grad is None else
param.grad.data.narrow(0,
0,
param.ds_tensor.numel())
for param in sub_group
]
# self.averaged_gradients[i] = self.get_flat_partition(
# self.fp16_groups[i],
# 0,
# self.fp32_partitioned_groups_flat[i].numel(),
# return_tensor_list=True)
self._release_ipg_buffers()
if self.verbose:
report_memory_usage(f"End ipg_epilogue")
# resets all partition to no reduced
# sets remianing grads to the total number of grads in each partition
# set is grad computed to false for all grads in partition
def reset_partition_gradient_structures(self):
total_partitions = dist.get_world_size(group=self.dp_process_group)
for i, _ in enumerate(self.fp16_groups):
for partition_id in range(total_partitions):
self.is_partition_reduced[i][partition_id] = False
self.remaining_grads_in_partition[i][
partition_id] = self.total_grads_in_partition[i][partition_id]
for param_id in self.is_grad_computed[i][partition_id]:
self.is_grad_computed[i][partition_id][param_id] = False
def initialize_gradient_partition(self, i, param_group, partition_id):
def set_key_value_list(dictionary, key, value):
if key in dictionary:
dictionary[key].append(value)
else:
dictionary[key] = [value]
def increment_value(dictionary, key):
if key in dictionary:
dictionary[key] += 1
else:
dictionary[key] = 1
partition_size = self.partition_size[i]
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for param in param_group:
param_size = param.numel()
param_id = self.get_param_id(param)
if (current_index >= start_index and current_index < end_index):
set_key_value_list(self.param_to_partition_ids[i],
param_id,
partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][
param_id] = current_index - start_index
self.grad_start_offset[i][partition_id][param_id] = 0
elif start_index > current_index and start_index < (current_index +
param_size):
assert (
first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
set_key_value_list(self.param_to_partition_ids[i],
param_id,
partition_id)
increment_value(self.total_grads_in_partition[i], partition_id)
self.is_grad_computed[i][partition_id][param_id] = False
self.grad_partition_insertion_offset[i][partition_id][param_id] = 0
self.grad_start_offset[i][partition_id][param_id] = first_offset
current_index = current_index + param_size
def overlapping_partition_gradients_reduce_epilogue(self):
self.independent_gradient_partition_epilogue()
self.zero_grad()
def create_reduce_and_remove_grad_hooks(self):
if self.verbose:
print_rank_0(f'[Begin] Create gradient reduction hooks')
self.grad_accs = []
for i, param_group in enumerate(self.fp16_groups):
for param in param_group:
if param.requires_grad:
# print_rank_0(f" Before all gather {param.device}, {param.shape}")
# The hook must be created in un-partitioned parameter
param.all_gather()
# print(f"After all gather {param.device}, {param.shape}")
def wrapper(param, i):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def reduce_partition_and_remove_grads(*notneeded):
self.reduce_ready_partitions_and_remove_grads(
param, i)
grad_acc.register_hook(
reduce_partition_and_remove_grads)
self.grad_accs.append(grad_acc)
# print(f"param grad fn {param.expand_as(param).grad_fn}")
wrapper(param, i)
# Partition the parameter after creating the hook
param.partition()
if self.verbose:
print_rank_0(f'[End] Create gradient reduction hooks')
def get_param_id(self, param):
unique_id = id(param)
return self.param_id[unique_id]
def report_ipg_memory_usage(self, tag, param_elems):
elem_count = self.elements_in_ipg_bucket + param_elems
percent_of_bucket_size = (
100.0 * elem_count) // self.reduce_bucket_size
report_memory_usage(
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}")
###############Idependent Partition Gradient ########################
def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
# print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True)
# Because the ipg bucket is initialized with a random place holder tensor, we must
# explicitly check that the bucket has any real data in it (self.elements_in_ipg_bucket >
# 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a
# garbage data and `self.average_tensor()` will crash because its params_to_reduce will be
# empty, while reduction_list will have that garbage data.
if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size:
if self.verbose:
self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
param.ds_numel)
self.reduce_ipg_grads()
if self.contiguous_gradients and self.overlap_comm:
# Swap ipg_index between 0 and 1
self.ipg_index = 1 - self.ipg_index
if self.verbose:
self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads",
param.ds_numel)
param_id = self.get_param_id(param)
assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Gradient computed twice for this partition. \
Multiple gradient reduction is currently not supported"
# keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
if param.ds_numel > self.reduce_bucket_size:
self.extra_large_param_to_reduce = param
elif self.contiguous_gradients:
# print_rank_0("before new grad tensor move")
new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
0,
self.elements_in_ipg_bucket,
param.ds_numel)
# print_rank_0("after new grad tensor move")
new_grad_tensor.copy_(param.grad.view(-1))
param.grad.data = new_grad_tensor.data.view_as(param.grad)
self.elements_in_ipg_bucket += param.ds_numel
self.grads_in_ipg_bucket.append(param.grad)
self.params_in_ipg_bucket.append((i, param, param_id))
if self.verbose:
self.report_ipg_memory_usage("End ipg_remove_grads", 0)
def gradient_reduction_w_predivide(self, tensor):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
tensor_to_allreduce = tensor
if self.allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
if self.postscale_gradients:
if self.gradient_predivide_factor != 1.0:
tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if self.gradient_predivide_factor != dp_world_size:
tensor_to_allreduce.mul_(
self.gradient_predivide_factor / dp_world_size)
else:
tensor_to_allreduce.div_(dp_world_size)
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
tensor.copy_(tensor_to_allreduce)
return tensor
def average_tensor(self, tensors, params_to_reduce):
with torch.cuda.stream(self.reduction_stream):
if not self.reduce_scatter:
for tensor in tensors:
self.gradient_reduction_w_predivide(tensor)
return
for tensor in tensors:
tensor.div_(dist.get_world_size(group=self.dp_process_group))
# reduction resulting with each rank only holding the gradient partition it owns
# This could either be a reduce scatter or a reduce op depending on how
# parameters are partitionied. The method is implemented by the
# DeepSpeed param extensions to the pytorch parameter, so its up to
# the extension to define what happens here
params_to_reduce[0].reduce_gradients_at_owner(
param_list=params_to_reduce,
hierarchy=self.param_coordinator.hierarchy)
def set_grad_positions(self):
for i, group in enumerate(self.fp16_groups):
current_offset = 0
for param in group:
param_id = self.get_param_id(param)
num_elements = param.ds_tensor.ds_numel
self.grad_position[param_id] = [
int(i),
int(current_offset),
int(num_elements)
]
# print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}")
current_offset += num_elements
def async_accumulate_grad_in_cpu_via_gpu(self, param, acc_grad_cpu_partition):
# copy to a preexisiting buffer to avoid memory allocation penalty
dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(
0,
0,
param.ds_tensor.ds_numel)
if self.micro_step_id > 0:
dest_buffer.copy_(
acc_grad_cpu_partition.view(-1), non_blocking=True)
param.grad.data.view(-1).add_(dest_buffer)
# at the boundary we will send 32bit directly
if not self.is_gradient_accumulation_boundary:
acc_grad_cpu_partition.data.copy_(param.grad.data.view(-1),
non_blocking=True)
def _constant_buffered_norm2(self, input, buffer_size=250000000):
norm = None
for part in input.view(-1).split(buffer_size):
if norm is None:
norm = part.data.double().norm(2) ** 2.0
else:
norm += part.data.double().norm(2) ** 2.0
return norm ** 0.5
def set_norm_for_param_grad_in_gpu(self, param):
param_id = self.get_param_id(param)
# self.norm_for_param_grads[param_id] = param.grad.data.double().norm(2)
# Using a more memory efficient version
self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(
param.grad)
def update_overflow_tracker_for_param_grad(self, param):
# Credit to our user David Minn
if param.grad is not None:
if self.overlap_comm:
self.gpu_sum = self.gpu_sum + param.grad.data.float().sum()
elif self._has_inf_or_nan(param.grad.data):
self.local_overflow = True
def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor):
with torch.cuda.stream(self.copy_grad_stream):
param_id = self.get_param_id(param)
src_tensor = param.grad.view(-1).float()
# print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}")
fp32_grad_tensor.copy_(src_tensor, non_blocking=True)
param.grad = None
def complete_grad_norm_calculation_for_cpu_offload(self, params):
total_norm = 0.0
norm_type = 2.0
for p in params:
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_id = self.get_param_id(p)
if param_id in self.norm_for_param_grads.keys():
param_norm = self.norm_for_param_grads[param_id]
total_norm += param_norm.item() ** 2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
if total_norm == float(
'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
def partition_previous_reduced_grads(self):
if not self.previous_reduced_grads:
return
if self.offload_optimizer:
allocate_grads_in_partition = self.grads_in_partition is None \
and self.gradient_accumulation_steps > 1
else:
allocate_grads_in_partition = self.grads_in_partition is None
if allocate_grads_in_partition:
self.grads_in_partition = []
for i, group in enumerate(self.fp16_groups):
total_size = 0
for param_in_partition in group:
total_size += param_in_partition.ds_tensor.ds_numel
if self.verbose:
report_memory_usage(
f"group {i} before creating {total_size} reduced gradients into partition")
if self.offload_param_pin_memory:
self.grads_in_partition.append(
torch.zeros(int(total_size),
dtype=self.dtype,
device=self.device).pin_memory())
else:
self.grads_in_partition.append(
torch.zeros(int(total_size),
dtype=self.dtype,
device=self.device))
if self.verbose:
report_memory_usage(
f"group {i} after creating {total_size} reduced gradients into partition")
if self.offload_optimizer:
offload_fp32_gradients = {}
offload_fp32_offsets = {}
with torch.cuda.stream(self.copy_grad_stream):
self.reduction_stream.synchronize()
for param in self.previous_reduced_grads:
[i,
dest_offset,
num_elements] = self.grad_position[self.get_param_id(param)]
if self.offload_optimizer:
param.partition_gradients(
partition_buffers=self.temp_grad_gpu_buffer)
# with torch.cuda.stream(self.copy_grad_stream):
# self.reduction_stream.synchronize()
if self.gradient_accumulation_steps > 1:
# The allreduce buffer will be rewritted. Copy the gradients in partition to a new buffer
fp16_grad_tensor = self.grads_in_partition[i].narrow(
0,
dest_offset,
num_elements)
self.async_accumulate_grad_in_cpu_via_gpu(
param,
fp16_grad_tensor)
if self.is_gradient_accumulation_boundary:
self.set_norm_for_param_grad_in_gpu(param)
self.update_overflow_tracker_for_param_grad(param)
if self._swappable_optimizer_subgroup(i):
if not i in offload_fp32_gradients.keys():
offload_fp32_gradients[i] = []
offload_fp32_offsets[i] = []
offload_fp32_gradients[i].append(
param.grad.view(-1).float())
param.grad = None
offload_fp32_offsets[i].append(dest_offset)
else:
fp32_grad_tensor = self.fp32_partitioned_groups_flat[
i].grad.narrow(0,
dest_offset,
num_elements)
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(
param,
fp32_grad_tensor)
else:
# The allreduce buffer will be rewritted. Copy the gradients in partition to a new buffer
fp16_grad_tensor = self.grads_in_partition[i].narrow(
0,
dest_offset,
num_elements)
param.partition_gradients(
partition_buffers=fp16_grad_tensor,
accumulate=True if self.micro_step_id > 0 else False)
if self.offload_optimizer and self.swap_optimizer:
for i in offload_fp32_gradients.keys():
self.optimizer_swapper.swap_out_gradients(
parameter=self.fp32_partitioned_groups_flat[i],
gradient_offsets=offload_fp32_offsets[i],
gradient_tensors=offload_fp32_gradients[i])
self.previous_reduced_grads = []
def reduce_ipg_grads(self, extra_param=None):
if self.overlap_comm:
self.reduction_stream.synchronize()
with torch.cuda.stream(self.reduction_stream):
self.partition_previous_reduced_grads()
params_to_reduce = [param for i, param,
param_id in self.params_in_ipg_bucket]
# print(f"Params in ipg bucket {self.params_in_ipg_bucket}")
# print(f"Reducing {[(debug_param2name_id_shape(param), param.grad) for param in params_to_reduce]}")
# exit(0)
if self.contiguous_gradients:
reduction_list = [self.ipg_buffer[self.ipg_index]]
if self.extra_large_param_to_reduce is not None:
reduction_list.append(self.extra_large_param_to_reduce.grad)
self.extra_large_param_to_reduce = None
self.average_tensor(reduction_list, params_to_reduce)
else:
self.buffered_reduce_fallback(
None,
self.grads_in_ipg_bucket,
elements_per_buffer=self.elements_in_ipg_bucket)
for _, param, param_id in self.params_in_ipg_bucket:
self.params_already_reduced[param_id] = True
self.previous_reduced_grads = params_to_reduce
self.grads_in_ipg_bucket = []
self.params_in_ipg_bucket = []
self.elements_in_ipg_bucket = 0
#####################################################################
def reduce_ready_partitions_and_remove_grads(self, param, i):
# print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True)
self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
def zero_reduced_gradients(self, partition_id, i):
def are_all_related_partitions_reduced(params_id):
for partition_id in self.param_to_partition_ids[i][params_id]:
if not self.is_partition_reduced[i][partition_id]:
return False
return True
for params_id in self.is_grad_computed[i][partition_id]:
if are_all_related_partitions_reduced(params_id):
self.param_dict[params_id].grad = None
def flatten_and_print(self, message, tensors, start=0, n=5):
flatten_tensor = self.flatten(tensors)
def print_func():
print(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
self.sequential_execution(print_func, message)
def get_grads_to_reduce(self, i, partition_id):
def get_reducable_portion(key):
grad = self.param_dict[key].grad
total_elements = grad.numel()
start = self.grad_start_offset[i][partition_id][key]
num_elements = min(
total_elements - start,
self.partition_size[i] -
self.grad_partition_insertion_offset[i][partition_id][key])
if not pg_correctness_test:
if num_elements == total_elements:
return grad
else:
return grad.contiguous().view(-1).narrow(0,
int(start),
int(num_elements))
else:
if num_elements == total_elements:
return grad.clone()
else:
return grad.clone().contiguous().view(-1).narrow(
0,
int(start),
int(num_elements))
grads_to_reduce = []
for key in self.is_grad_computed[i][partition_id]:
grad = get_reducable_portion(key)
grads_to_reduce.append(grad)
return grads_to_reduce
def sequential_execution(self, function, message, group=None):
if group is None:
group = self.dp_process_group
if dist.get_rank(group=group) == 0:
print(message)
for id in range(dist.get_world_size(group=group)):
if id == dist.get_rank(group=group):
function()
dist.barrier(group=group)
def set_none_gradients_to_zero(self, i, partition_id):
for param_id in self.is_grad_computed[i][partition_id]:
param = self.param_dict[param_id]
if param.grad is None:
param.grad = torch.zero_like(param)
######################Reduction Related Methods##############################
def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
rank = None
tensor = self.flatten(bucket)
tensor_to_allreduce = tensor
if pg_correctness_test:
allreduce_always_fp32 = True
if allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
tensor_to_allreduce.div_(
dist.get_world_size(group=self.dp_process_group))
if rank is None:
# "All Reducing"
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
else:
global_rank = _get_global_rank(self.dp_process_group, rank)
dist.reduce(tensor_to_allreduce, global_rank,
group=self.dp_process_group)
if allreduce_always_fp32 and tensor is not tensor_to_allreduce:
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
tensor.copy_(tensor_to_allreduce)
return tensor
# if rank is specified do a reduction instead of an allreduce
def allreduce_and_copy(self, small_bucket, rank=None, log=None):
with torch.cuda.stream(self.reduction_stream):
allreduced = self.allreduce_bucket(
small_bucket, rank=rank, log=log)
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
buf.copy_(synced)
def allreduce_no_retain(self,
bucket,
numel_per_bucket=500000000,
rank=None,
log=None):
small_bucket = []
numel = 0
for tensor in bucket:
small_bucket.append(tensor)
numel = numel + tensor.numel()
if numel > numel_per_bucket:
self.allreduce_and_copy(small_bucket, rank=rank, log=None)
small_bucket = []
if len(small_bucket) > 0:
self.allreduce_and_copy(small_bucket, rank=rank, log=log)
# allows using reduction of gradients instead of using all_reduce
def buffered_reduce_fallback(self,
rank,
grads,
elements_per_buffer=500000000,
log=None):
split_buckets = split_half_float_double(grads)
for i, bucket in enumerate(split_buckets):
self.allreduce_no_retain(bucket,
numel_per_bucket=elements_per_buffer,
rank=rank,
log=log)
#############################################################################
#############################################################################
#############################################################################
# views the tensor as multiple partitions and returns
# those partitions
def get_data_parallel_partitions(self, tensor):
partitions = []
dp = dist.get_world_size(group=self.dp_process_group)
dp_id = dist.get_rank(group=self.dp_process_group)
total_num_elements = tensor.numel()
base_size = total_num_elements // dp
remaining = total_num_elements % dp
start = 0
for id in range(dp):
partition_size = base_size
if id < remaining:
partition_size = partition_size + 1
partitions.append(tensor.narrow(0, start, partition_size))
start = start + partition_size
return partitions
def get_partition_info(self, tensor_list, partition_size, partition_id):
params_in_partition = []
params_not_in_partition = []
start_index = partition_size * partition_id
end_index = partition_size * (partition_id + 1)
current_index = 0
first_offset = 0
for tensor in tensor_list:
tensor_size = tensor.numel()
if (current_index >= start_index and current_index < end_index):
params_in_partition.append(tensor)
elif start_index > current_index and start_index < (current_index +
tensor_size):
params_in_partition.append(tensor)
assert (
first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
first_offset = start_index - current_index
else:
params_not_in_partition.append(tensor)
current_index = current_index + tensor_size
return params_in_partition, params_not_in_partition, first_offset
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def _model_parallel_all_reduce(self, tensor, op):
""" Perform all reduce within model parallel group, if any.
"""
if self.model_parallel_group is None:
pass
else:
torch.distributed.all_reduce(tensor=tensor,
op=op,
group=self.model_parallel_group)
def clip_grad_norm(self, *args, **kwargs):
# dummy function to retain the same function interface
# as ColossalaiOptimizer for compatibility
pass
def get_grad_norm_direct(self, gradients, params, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from ``torch.nn.utils.clip_grad.clip_grad_norm_`` and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=self.dp_process_group)
# Take max across all GPUs.
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
# if dist.get_rank() == 0:
# print()(f"Total Norm begining {total_norm}")
for g, p in zip(gradients, params):
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_norm = g.data.double().norm(2)
total_norm += param_norm.item() ** 2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
if total_norm == float(
'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
# creates a flat fused tensor from the tensor list starting at the first_offset
# in the first tensor of the list. If there are not enough elements in the tensor
# list then the flat tensor will be padded with zeros
def get_flat_partition(self,
tensor_list,
first_offset,
partition_size,
return_tensor_list=False):
flat_tensor_list = []
current_size = 0
for i, tensor in enumerate(tensor_list):
if tensor.grad is None:
tensor.grad = torch.zeros_like(tensor)
tensor = tensor.grad
num_elements = tensor.numel()
tensor_offset = 0
# we need to offset to get to the right element
if i == 0 and first_offset > 0:
tensor_offset = first_offset
num_elements = num_elements - tensor_offset
# we dont need all elements of the tensor
if num_elements > (partition_size - current_size):
num_elements = partition_size - current_size
# we need a narrow view of the tensor based on the tensor offset and number of elements that
# we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0,
int(tensor_offset),
int(num_elements)))
else:
flat_tensor_list.append(tensor)
current_size = current_size + num_elements
# this means its the last partition and does not align with the dp boundary. We need to pad before flattening
if current_size < partition_size:
flat_tensor_list.append(
torch.zeros(int(partition_size - current_size),
dtype=tensor_list[0].dtype,
device=tensor_list[0].device))
if return_tensor_list:
return flat_tensor_list
return self.flatten(flat_tensor_list)
def free_grad_in_param_list(self, param_list):
for p in param_list:
p.grad = None
def reset_cpu_buffers(self):
self.norm_for_param_grads = {}
self.local_overflow = False
def log_timers(self, timer_names):
if self.timers is None:
return
self.timers.log(names=list(timer_names))
def start_timers(self, timer_names):
if self.timers is None:
return
for name in timer_names:
self.timers(name).start()
def stop_timers(self, timer_names):
if self.timers is None:
return
for name in timer_names:
self.timers(name).stop()
def _pre_step(self):
self.micro_step_id = INITIAL_MICRO_STEP_ID
if self.verbose:
print_rank_0(f"Inside Step function")
report_memory_usage(f"In step before checking overflow")
print_rank_0("Finished Tracing at Beginning of Step")
self.param_coordinator.hierarchy = 0
self.param_coordinator.finish_tracing(print_trace=True)
self.param_coordinator.reset_step()
if self.verbose:
print_rank_0("Finished Tracing at Beginning of Step")
def _get_norm_groups(self):
norm_groups = []
for i, group in enumerate(self.fp16_groups):
if self.offload_optimizer:
norm_groups.append(
self.complete_grad_norm_calculation_for_cpu_offload(
self.fp16_groups[i]))
else:
norm_groups.append(
self.get_grad_norm_direct(self.averaged_gradients[i],
self.fp16_groups[i]))
return norm_groups
def _prepare_fp32_grad_for_sub_group(self, sub_group_id):
partition_id = dist.get_rank(group=self.dp_process_group)
single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to(
self.fp32_partitioned_groups_flat[sub_group_id].dtype)
assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \
"averaged gradients have different number of elements that partition size {} {} {} {}".format(
single_grad_partition.numel(
), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id,
partition_id)
self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition
# release all the gradient since we have already created a necessary copy in dp_grad_partition
self.zero_grad()
self.averaged_gradients[sub_group_id] = None
def _prepare_sub_group(self, sub_group_id, timer_names=set()):
if self.verbose:
report_memory_usage(
f'Before prepare optimizer sub group {sub_group_id}')
if self._swappable_optimizer_subgroup(sub_group_id):
self._optimizer_states_and_gradient_swap_in(
sub_group_id, timer_names)
elif not self.offload_optimizer:
self._prepare_fp32_grad_for_sub_group(sub_group_id)
if self.verbose:
report_memory_usage(
f'After prepare optimizer sub group {sub_group_id}')
def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=set()):
param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id]
fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id])
assert self._swappable_optimizer_subgroup(sub_group_id), \
f'Parameter {fp32_param_id} of numel={param_length} is not swappable'
OPTIMIZER_SWAP_IN_STATE = 'optimizer_swap_in_state'
if self.verbose:
report_memory_usage(
f'pre-step Before swapping in optimizer tensors {sub_group_id}')
self.start_timers([OPTIMIZER_SWAP_IN_STATE])
self.optimizer_swapper.swap_in_optimizer_state(
parameter=self.fp32_partitioned_groups_flat[sub_group_id],
async_parameter=self.next_swappable_fp32_partitioned_groups[sub_group_id])
self.stop_timers([OPTIMIZER_SWAP_IN_STATE])
timer_names.add(OPTIMIZER_SWAP_IN_STATE)
if self.verbose:
report_memory_usage(
f'pre-step After swapping in optimizer tensors {sub_group_id}')
def _release_sub_group(self, sub_group_id, timer_names=set()):
if self.verbose:
report_memory_usage(
f'Before release optimizer sub group {sub_group_id}')
# get rid of the fp32 gradients. Not needed anymore
if not self.offload_optimizer:
self.fp32_partitioned_groups_flat[sub_group_id].grad = None
if self._swappable_optimizer_subgroup(sub_group_id):
self._optimizer_states_and_gradient_swap_out(
sub_group_id, timer_names)
if self.verbose:
report_memory_usage(
f'After release optimizer sub group {sub_group_id}')
# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(self, tensor_list, alignment):
num_elements = 0
for tens in tensor_list:
num_elements = num_elements + tens.numel()
remaining = num_elements % alignment
if remaining:
elements_to_add = alignment - remaining
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
num_elements = num_elements + elements_to_add
else:
padded_tensor_list = tensor_list
return self.flatten(padded_tensor_list)
def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=set()):
param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id]
fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id])
assert self._swappable_optimizer_subgroup(sub_group_id), \
f'Parameter {fp32_param_id} of numel={param_length} is not swappable'
OPTIMIZER_SWAP_OUT_STATE = 'optimizer_swap_out_state'
if self.verbose:
report_memory_usage(
f'post-step Before swapping out optimizer tensors {sub_group_id}')
self.start_timers([OPTIMIZER_SWAP_OUT_STATE])
self.optimizer_swapper.swap_out_optimizer_state(
parameter=self.fp32_partitioned_groups_flat[sub_group_id],
async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] is
not None)
self.stop_timers([OPTIMIZER_SWAP_OUT_STATE])
if self.verbose:
report_memory_usage(
f'post-step After swapping out optimizer tensors {sub_group_id}')
timer_names.add(OPTIMIZER_SWAP_OUT_STATE)
# get rid of the fp32 gradients. Not needed anymore
self.fp32_partitioned_groups_flat[sub_group_id].grad = None
def _unflatten_partitioned_parameters(self, sub_group_id):
updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id],
self.fp16_partitioned_groups[sub_group_id])
for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params):
partitioned_param.data = q.data
def _overflow_clean_up(self, prev_scale):
if self.verbose:
report_memory_usage('After overflow before clearing gradients')
self.zero_grad()
if self.offload_optimizer:
self.reset_cpu_buffers()
else:
self.averaged_gradients = {}
if self.verbose:
report_memory_usage('After overflow after clearing gradients')
if torch.distributed.get_rank() == 0:
print(
"[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
"reducing to {}".format(dist.get_rank(),
prev_scale,
self.loss_scale))
def _overflow_check_and_loss_scale_update(self):
# First compute norm for all group so we know if there is overflow
self.check_overflow()
# loss scaling related computation
prev_scale = self.loss_scale
self._update_scale(self.overflow)
if self.overflow:
self._overflow_clean_up(prev_scale)
return self.overflow
def _post_step(self, timer_names=set()):
if self.offload_optimizer:
self.reset_cpu_buffers()
# Gathering persisting parameters
if len(self.persistent_parameters) > 0:
self.persistent_parameters[0].all_gather(
self.persistent_parameters)
if self.swap_optimizer:
self.optimizer_swapper.log_timers()
self.log_timers(timer_names)
if self.verbose:
report_memory_usage('After zero_optimizer step')
print_rank_0(
f"------------------Finishing Step-----------------------")
def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id):
if self.fp16_partitioned_groups_flat[sub_group_id] is not None:
self.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
self.fp32_partitioned_groups_flat[sub_group_id].data)
# unflatten fp16 parameter subgroup
self._unflatten_partitioned_parameters(sub_group_id)
else:
self._partitioned_params_swap_out(sub_group_id)
def allreduce_gradients(self):
self.overlapping_partition_gradients_reduce_epilogue()
def step(self, closure=None):
"""
Not supporting closure.
"""
self._pre_step()
# checks for overflow, adjust the loss scale accordingly
if self._overflow_check_and_loss_scale_update():
if self.swap_optimizer:
self.optimizer_swapper.log_timers()
return
norm_groups = self._get_norm_groups()
timer_names = set()
timer_names.add('optimizer_step')
self.start_timers(['optimizer_step'])
# update parameters one sub group at a time
for sub_group_id, group in enumerate(self.fp16_groups):
# prepare optimizer states, gradients and fp32 parameters for update
self._prepare_sub_group(sub_group_id, timer_names)
# scale the fp32 gradients
self.unscale_and_clip_grads(sub_group_id, norm_groups)
# apply the optimizer step on the sub group and copy fp32 parameters to fp16
self._optimizer_step(sub_group_id)
# put fp16 parameters in appropriate location
self._reassign_or_swap_out_partitioned_parameters(sub_group_id)
# release memory or swap out optimizer states of fp32 parameters
self._release_sub_group(sub_group_id, timer_names)
self.stop_timers(['optimizer_step'])
self._post_step(timer_names)
return
def dump_pre_step_gradients(self, debug_fp32_grads):
# Dump gradient norms for debbuging
for i, _ in enumerate(self.fp16_groups):
if self.verbose:
print(
f'Pre-Step Dump Norms for Group {i} FP16P, FP16G, FP32G, FP32GUC')
for fp16_param, fp32_grad in zip(self.fp16_groups[i], debug_fp32_grads[i]):
param_id = self.get_param_id(fp16_param)
fp16_grad_norm = self.debug_fp16_grads[i][param_id]
fp32_grad_norm = [float(t.data.float().norm(2))
for t in fp32_grad]
norm_list = [fp16_grad_norm, fp32_grad_norm]
if self.verbose:
print(f'Pre-Step Norms {i} {param_id} = {norm_list}')
def dump_post_step_gradients(self):
# Dump gradient norms for debbuging
for i, group in enumerate(self.fp16_groups):
if self.verbose:
print(
f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT')
unflat_fp16 = self.unflatten(
self.fp16_groups_flat[i], self.fp16_groups[i])
unflat_fp32 = self.unflatten(self.fp32_partitioned_groups_flat[i],
self.fp16_groups[i])
for j, p in enumerate(self.fp16_groups[i]):
param_id = self.get_param_id(p)
param_norm = float(p.data.float().norm(2))
ds_norm = float(p.ds_tensor.data.float().norm(2))
unflat_norm = [
float(t.data.float().norm(2))
for t in [unflat_fp16[j],
unflat_fp32[j]]
]
norm_list = [param_norm, ds_norm] + unflat_norm
if self.verbose:
print(f'Post-Step Norms {i} {param_id} = {norm_list}')
def unscale_and_clip_grads(self, sub_group_id, norm_groups):
grad_groups_flat = [
self.fp32_partitioned_groups_flat[sub_group_id].grad]
total_norm = 0.0
for norm in norm_groups:
total_norm += norm ** 2.0
total_norm = math.sqrt(total_norm)
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
if clip > 1:
combined_scale = clip * self.loss_scale
for grad in grad_groups_flat:
if isinstance(grad, list):
sub_partitions = grad
for g in sub_partitions:
g.data.mul_(1. / combined_scale)
else:
grad.data.mul_(1. / combined_scale)
def _check_overflow(self, partition_gradients=True):
self.overflow = self.has_overflow(partition_gradients)
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params, is_grad_list=False):
for p in params:
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
return True
return False
def has_overflow_partitioned_grads_serial(self):
for i in range(len(self.fp16_groups)):
for j, grad in enumerate(self.averaged_gradients[i]):
if grad is not None and self._has_inf_or_nan(grad.data, j):
return True
return False
def has_overflow(self, partition_gradients=True):
if partition_gradients:
if self.overlap_comm:
self.local_overflow = self._has_inf_or_nan(self.gpu_sum)
self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda()
overflow = self.local_overflow if self.offload_optimizer else self.has_overflow_partitioned_grads_serial(
)
# overflow = self.has_overflow_partitioned_grads_serial()
overflow_gpu = torch.cuda.ByteTensor([overflow])
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=self.dp_process_group)
else:
params = []
for group in self.fp16_groups:
for param in group:
params.append(param)
overflow = self.has_overflow_serial(
params, is_grad_list=partition_gradients)
overflow_gpu = torch.cuda.ByteTensor([overflow])
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
self._model_parallel_all_reduce(tensor=overflow_gpu,
op=torch.distributed.ReduceOp.MAX)
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x, j=None):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
# We want to check if inst is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False
def backward(self, loss, retain_graph=False):
"""
:attr:`backward` performs the following steps:
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
self.micro_step_id += 1
if self.verbose:
print_rank_0(
f"Total fully available parameters {self.param_coordinator.total_available_parameter_numel}"
)
if self.swap_optimizer:
self.optimizer_swapper.pre_backward()
if self.verbose:
report_memory_usage(f"Before backward")
if self.contiguous_gradients:
self.ipg_buffer = []
buf_0 = torch.empty(self.reduce_bucket_size,
dtype=self.dtype,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_0)
# Use double buffers to avoid data access conflict when overlap_comm is enabled.
if self.overlap_comm:
buf_1 = torch.empty(self.reduce_bucket_size,
dtype=self.dtype,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_1)
self.ipg_index = 0
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
'''Partitioning Parameters that were not partitioned
Usually if parameters of modules whose input parameters do not require
grad computation do not trigger post call and will therefore will remain unpartitioned '''
self._partition_all_parameters()
if self.swap_optimizer:
self.optimizer_swapper.post_backward()
def _partition_all_parameters(self):
for name, param in self.module.named_parameters(recurse=True):
self.param_coordinator.release_and_reset_parameter(param)
def check_overflow(self, partition_gradients=True):
self._check_overflow(partition_gradients)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings):
# Remove paddings from flattened tensor
individual_tensors = self.unflatten(
padded_flattened_tensor, group_tensors)
lean_lengths = [t.numel() - pad for t,
pad in zip(group_tensors, paddings)]
lean_tensors = [t[:len]
for t, len in zip(individual_tensors, lean_lengths)]
# print()(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}')
return lean_tensors
# TODO REVISIT this for stage 3
def get_lean_optimizer_state(self):
# Return optimizer states after removing paddings.
# This method assumes that each param group contains a single flattened tensor.
optimizer_groups_state = []
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
lean_state = {}
for key, value in self.optimizer.state[p].items():
if torch.is_tensor(value):
padded_lens = [t.numel()
for t in self.fp16_partitioned_groups[i]]
lean_state[key] = self._get_lean_tensors(
value,
self.fp16_partitioned_groups[i],
self.groups_padding[i])
lean_flat_len = sum([t.numel() for t in lean_state[key]])
else:
lean_state[key] = value
optimizer_groups_state.append(lean_state)
return optimizer_groups_state
def get_groups_without_padding(self, groups_with_padding):
# Return group tensor after removing paddings added for alignment to DP world size.
groups_without_padding = []
for i, group in enumerate(groups_with_padding):
lean_group = self._get_lean_tensors(group,
self.fp16_partitioned_groups[i],
self.groups_padding[i])
groups_without_padding.append(lean_group)
return groups_without_padding
def _set_fp32_optimizer_param_groups(self):
for sub_group_id, _ in enumerate(self.fp16_groups):
param_group_id = self.sub_group_to_group_id[sub_group_id]
self.optimizer.param_groups[param_group_id]['params'].append(
self.fp32_partitioned_groups_flat[sub_group_id])
def _clear_fp32_optimizer_param_groups(self):
for param_group in self.optimizer.param_groups:
param_group['params'] = []
def _rigid_state_dict(self):
state_dict = {}
state_dict['zero_stage'] = ZERO_OPTIMIZATION_WEIGHTS
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['partition_count'] = self.partition_count
self._set_fp32_optimizer_param_groups()
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict['fp32_flat_groups'] = self.fp32_partitioned_groups_flat
self._clear_fp32_optimizer_param_groups()
return state_dict
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
if self.elastic_checkpoint:
raise NotImplementedError(
"ZeRO-3 does not yet support elastic checkpointing, please disable for now."
)
if self.swap_optimizer or self.params_in_nvme_and_cpu:
raise NotImplementedError(
"ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now."
)
return self._rigid_state_dict()
# Restore base optimizer fp32 weights from checkpoint by:
# 1) Merging fp32 weights from checkpoints of all partitions
# 2) Extracting fp32 weights for current partition from merged weights
# 3) Using extracted weights to update base optimizer weights directly.
def _restore_from_fp32_weights(self, all_state_dict):
flat_local_partition = []
for i in range(len(self.fp32_partitioned_groups_flat)):
merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict]
flat_local_partition.append(
self._get_flattened_partition(merged_partitions))
for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition):
current.data.copy_(saved.data)
# Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self):
for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat,
self.fp32_partitioned_groups_flat):
fp32_partition.data.copy_(fp16_partitions.data)
# Refresh the fp32 master params from the fp16 copies.
def refresh_fp32_params(self):
self._restore_from_fp16_weights()
# Extract flattened partion for current rank from all partitions
def _get_flattened_partition(self, all_partition_states):
partition_id = dist.get_rank(group=self.dp_process_group)
alignment = dist.get_world_size(group=self.dp_process_group)
param_partitions = [[] for _ in range(len(all_partition_states[0]))]
for i, partition in enumerate(all_partition_states):
for j, param in enumerate(partition):
param_partitions[j].append(param)
local_state_partitions = []
for param_index, param_slices in enumerate(param_partitions):
flattened_merged_tensor = self.flatten_dense_tensors_aligned(
param_slices,
alignment)
new_partitions = self.get_data_parallel_partitions(
flattened_merged_tensor)
local_state_partitions.append(new_partitions[partition_id])
if torch.is_tensor(local_state_partitions[0]):
return self.flatten_dense_tensors_aligned(local_state_partitions, alignment)
# Assume non-tensor states are not partitioned and equal across ranks, so return first one
return local_state_partitions[0]
# Restore base optimizer state from checkpoint by
# 1) Merging optimizer state from checkpoints of all partitions
# 2) Extracting optimizer state for current partition from the merged state
# 3) Using the extracted value to directly update the base optimizer.
def _restore_base_optimizer_state(self, all_state_dict):
base_optimizer_group_states = []
for i in range(len(self.optimizer.param_groups)):
partition_states = {}
all_partition_group_states = [
sd['base_optimizer_state'][i] for sd in all_state_dict
]
for key in all_partition_group_states[0].keys():
all_partition_states = [
all_states[key] for all_states in all_partition_group_states
]
partition_states[key] = self._get_flattened_partition(
all_partition_states)
base_optimizer_group_states.append(partition_states)
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
for key, saved in base_optimizer_group_states[i].items():
if torch.is_tensor(self.optimizer.state[p][key]):
self.optimizer.state[p][key].data.copy_(saved.data)
else:
self.optimizer.state[p][key] = saved
def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
if load_optimizer_states:
self._set_fp32_optimizer_param_groups()
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
self._clear_fp32_optimizer_param_groups()
# restore fp32 partitions
for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict['fp32_flat_groups']):
curr_param.data.copy_(saved_param.data)
# restore fp16 partitions from fp32
for sub_group_id in range(len(self.fp32_partitioned_groups_flat)):
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
fp16_param = self.fp16_partitioned_groups_flat[sub_group_id]
fp16_param.data.copy_(fp32_param.data)
# update fp16 unflattened params
for sub_group_id in range(len(self.fp16_partitioned_groups_flat)):
updated_params = self.unflatten(
self.fp16_partitioned_groups_flat[sub_group_id],
self.fp16_partitioned_groups[sub_group_id])
for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params):
partitioned_param.data = q.data
# TODO: Support different/changing load/save DP degree.
def load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
r"""Loading a ZeRO checkpoint
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Arguments:
state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
Note that the number of saved partitions may differ from number of loading partitions to support
changing GPU count, specifically DP world size, between saving and loading checkpoints.
load_optimizer_states: Boolean indicating whether or not to load base optimizer states
load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32
copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss).
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
if self.elastic_checkpoint:
raise NotImplementedError(
"ZeRO-3 does not yet support elastic checkpointing, please disable for now."
)
if self.swap_optimizer or self.params_in_nvme_and_cpu:
raise NotImplementedError(
"ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now."
)
self._rigid_load_state_dict(
state_dict_list[dist.get_rank(group=self.dp_process_group)],
load_optimizer_states=load_optimizer_states)
if len(self.persistent_parameters) > 0:
self.persistent_parameters[0].partition(self.persistent_parameters)
self.persistent_parameters[0].all_gather(
self.persistent_parameters)
def save_checkpoint_prologue(self):
self._partition_all_parameters()
def save_checkpoint_epilogue(self):
if len(self.persistent_parameters) > 0:
self.persistent_parameters[0].all_gather(
self.persistent_parameters)
def _handle_overflow(cpu_sum, x, i):
import math
rank = torch.distributed.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
if not math.isfinite(float(v)):
t_i = v_i
break
print(
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
)
def estimate_zero3_model_states_mem_needs(total_params,
largest_layer_params,
num_gpus_per_node=1,
num_nodes=1,
cpu_offload=True,
cpu_offload_params=True,
zero_init=True,
additional_buffer_factor=1.5):
total_gpus = num_nodes * num_gpus_per_node
gpus_factor = 1 / num_nodes
largest_layer_memory = (4 * largest_layer_params)
if cpu_offload:
if cpu_offload_params:
gpu_mem = largest_layer_memory
if zero_init:
cpu_mem = total_params * 18 * gpus_factor * additional_buffer_factor
else:
cpu_mem = total_params * max(4 * num_gpus_per_node,
18 * gpus_factor) * additional_buffer_factor
else:
gpu_mem = largest_layer_memory + int(2 * total_params / total_gpus)
if zero_init:
cpu_mem = total_params * 16 * gpus_factor * additional_buffer_factor
else:
cpu_mem = total_params * max(4 * num_gpus_per_node,
16 * gpus_factor) * additional_buffer_factor
else:
gpu_mem = largest_layer_memory + int(18 * total_params / total_gpus)
if zero_init:
cpu_mem = largest_layer_params * 4 * num_gpus_per_node * additional_buffer_factor
else:
cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor
return int(cpu_mem), int(gpu_mem), largest_layer_memory
def model_to_params(model):
# shared params calculated only once
total_params = sum(
dict((p.data_ptr(),
p.numel()) for p in model.parameters()).values())
largest_layer_params = 0
for m in model.modules():
# assuming no shared params within a single layer
layer_params = sum(p.numel() for p in m.parameters(recurse=False))
largest_layer_params = max(largest_layer_params, layer_params)
return total_params, largest_layer_params
def estimate_zero3_model_states_mem_needs_all_live(model,
num_gpus_per_node=1,
num_nodes=1,
additional_buffer_factor=1.5):
"""
Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients
for a given ``model`` and hardware setup.
If you have an actual model object, use this function and everything will be derived
automatically.
If it's a hypothetical model, use ``estimate_zero3_model_states_mem_needs_all_cold`` where you have to pass
the ``total_params`` and ``largest_layer_params`` explicitly.
Args:
- ``model``: ``nn.Module`` object
- ``num_gpus_per_node``: how many gpus per node (defaults to 1)
- ``num_nodes``: how many nodes (defaults to 1),
- ``additional_buffer_factor``: estimation factor (defaults to 1.5):
"""
total_params, largest_layer_params = model_to_params(model)
estimate_zero3_model_states_mem_needs_all_cold(
total_params=total_params,
largest_layer_params=largest_layer_params,
num_gpus_per_node=num_gpus_per_node,
num_nodes=num_nodes,
additional_buffer_factor=additional_buffer_factor)
def estimate_zero3_model_states_mem_needs_all_cold(total_params,
largest_layer_params,
num_gpus_per_node=1,
num_nodes=1,
additional_buffer_factor=1.5):
"""
Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients
for a given ``model`` and hardware setup.
If it's a hypothetical model, use this function where you have to pass
the ``total_params`` and ``largest_layer_params`` explicitly.
If you have an actual model object, use ``estimate_zero3_model_states_mem_needs_all_live`` and everything
will be derived automatically.
Args:
- ``total_params``: total model params
- ``largest_layer_params``: largest layer's params
- ``num_gpus_per_node``: how many gpus per node (defaults to 1)
- ``num_nodes``: how many nodes (defaults to 1),
- ``additional_buffer_factor``: estimation factor (defaults to 1.5):
"""
def format_options(cpu_offload, cpu_offload_params, zero_init):
enabled = []
enabled.append(f"cpu_offload={1 if cpu_offload else 0}")
enabled.append(f"cpu_offload_params={1 if cpu_offload_params else 0}")
enabled.append(f"zero_init={1 if zero_init else 0}")
return ", ".join(enabled)
nodes_str = "nodes" if num_nodes > 1 else "node"
gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU"
print(
"Estimated memory needed for params, optim states and gradients for a:\n"
f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n"
f"SW: Model with {int(total_params / 1e6)}M total params, {int(largest_layer_params / 1e6)}M largest layer params."
)
print(" per CPU | per GPU | Options")
for cpu_offload in [True, False]:
for cpu_offload_params in [True, False]:
if not cpu_offload and cpu_offload_params:
continue
for zero_init in [True, False]:
cpu_mem, gpu_mem, largest_layer_memory = estimate_zero3_model_states_mem_needs(
total_params=total_params,
largest_layer_params=largest_layer_params,
num_gpus_per_node=num_gpus_per_node,
num_nodes=num_nodes,
cpu_offload=cpu_offload,
cpu_offload_params=cpu_offload_params,
zero_init=zero_init,
additional_buffer_factor=additional_buffer_factor
)
options_str = format_options(cpu_offload=cpu_offload,
cpu_offload_params=cpu_offload_params,
zero_init=zero_init)
print(
f" {cpu_mem / 2 ** 30:7.2f}GB | {gpu_mem / 2 ** 30:6.2f}GB | {options_str}")
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
import operator as op
from functools import partial, reduce
from typing import List
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
from colossalai.zero.sharded_model import ShardedModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
class Enumerator:
def __init__(self, arg_names: List[str], arg_values: List[tuple]) -> None:
self.arg_names = arg_names
self.enums = Enumerator.all_enumerate(arg_values)
def __len__(self):
return len(self.enums)
def __getitem__(self, idx):
return {name: self.enums[idx][i] for i, name in enumerate(self.arg_names)}
@staticmethod
def all_enumerate(args: List[tuple]):
num_states = reduce(op.mul, map(lambda xs: len(xs), args))
idxs = [0] * len(args)
states = []
for _ in range(num_states):
states.append(tuple(args[j][idx] for j, idx in enumerate(idxs)))
if len(states) == num_states:
break
i = 0
while idxs[i] + 1 == len(args[i]):
idxs[i] = 0
i += 1
idxs[i] += 1
return states
def checkpoint_wrapper(module, enable=True):
if enable:
module.forward = partial(checkpoint, module.forward)
return module
class Net(nn.Module):
def __init__(self, checkpoint=False) -> None:
super().__init__()
self.fc1 = nn.Linear(5, 5)
self.fc2 = nn.Linear(5, 5)
self.fc3 = nn.Linear(5, 1)
if checkpoint:
self.fc1 = checkpoint_wrapper(self.fc1)
self.layers = [
self.fc1,
self.fc2,
self.fc1,
self.fc2,
self.fc3
]
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
model.train()
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=enable_autocast):
y = model(x)
loss = y.sum()
loss = loss.float()
loss.backward()
clip_grad(model, norm_type)
optimizer.step()
def clip_grad(model, norm_type):
if isinstance(model, DDP):
clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type)
else:
clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type)
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
if loose:
return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
return torch.allclose(tensor_a, tensor_b)
def check_grads(model, zero_model, loose=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_grad = zero_p.grad.clone().to(p.device)
chunks = torch.flatten(p.grad).chunk(4)
if rank >= len(chunks):
continue
grad = chunks[rank]
if zero_p.zero_shard_padding > 0:
zero_grad = zero_grad[:-zero_p.zero_shard_padding]
assert grad.dtype == zero_grad.dtype
assert allclose(grad, zero_grad, loose=loose)
def check_params(model, zero_model, loose=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_shard_padding = zero_p.zero_shard_padding
zero_p = zero_p.clone().to(p.device)
chunks = torch.flatten(p).chunk(4)
if rank >= len(chunks):
continue
p = chunks[rank]
if zero_shard_padding > 0:
zero_p = zero_p[:-zero_shard_padding]
assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose)
def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):
model = Net(checkpoint=checkpoint).cuda()
zero_model = copy.deepcopy(model)
ddp_model = DDP(model)
offload_config = {}
if offload:
offload_config['device'] = 'cpu'
zero_model = zero_model.cpu()
zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
for _ in range(5):
x = torch.rand(2, 5).cuda()
run_step(ddp_model, optimizer, x, enable_autocast=fp16, norm_type=norm_type)
run_step(zero_model, zero_optimizer, x, enable_autocast=fp16, norm_type=norm_type)
check_grads(ddp_model, zero_model)
check_params(ddp_model, zero_model)
for _ in range(5):
x = torch.rand(2, 5).cuda()
run_step(ddp_model, optimizer, x, enable_autocast=False, norm_type=norm_type)
run_step(zero_model, zero_optimizer, x, enable_autocast=False, norm_type=norm_type)
check_grads(ddp_model, zero_model, loose=True)
check_params(ddp_model, zero_model, loose=True)
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={},
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
args = ['checkpoint', 'fp16', 'offload', 'norm_type']
arg_values = [(False, True), (False, True), (False, True), (1.0, 2.0, float('inf'))]
arg_enumerator = Enumerator(args, arg_values)
for kwargs in arg_enumerator:
if dist.get_rank() == 0:
print(kwargs)
check_config(**kwargs)
check_config()
@ pytest.mark.dist
def test_zero_clip_grad():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_clip_grad()
from functools import partial
from operator import imod
from colossalai.utils import checkpoint
import torch.nn as nn
import torch
from colossalai.logging import disable_existing_loggers, get_dist_logger
LOGGER = get_dist_logger()
CONFIG = dict(
fp16=dict(
mode=None,
),
zero=dict(
level=3,
verbose=False,
offload_optimizer_config=dict(
device='cpu',
pin_memory=True,
buffer_count=5,
fast_init=False
),
offload_param_config=dict(
device='cpu',
pin_memory=True,
buffer_count=5,
buffer_size=1e8,
max_in_cpu=1e9
)
),
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None)
)
)
def checkpoint_wrapper(module, enable=True):
if enable:
module.forward = partial(checkpoint, module.forward)
return module
class Net(nn.Module):
def __init__(self, checkpoint=False) -> None:
super().__init__()
self.fc1 = nn.Linear(5, 5)
self.fc2 = nn.Linear(5, 5)
self.fc3 = nn.Linear(5, 1)
if checkpoint:
self.fc1 = checkpoint_wrapper(self.fc1)
self.layers = [
self.fc1,
self.fc2,
self.fc1,
self.fc2,
self.fc3
]
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
if loose:
return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
return torch.allclose(tensor_a, tensor_b)
def check_grads(model, zero_model, loose=False):
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_grad = zero_p.grad.clone().to(p.device)
assert p.grad.dtype == zero_grad.dtype
assert allclose(p.grad, zero_grad, loose=loose)
LOGGER.info(torch.sum(p.grad-zero_grad))
def check_params(model, zero_model, loose=False):
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_p = zero_p.clone().to(p.device)
assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
from operator import mod
from pyexpat import model
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from tests.test_zero_data_parallel.common import Net, CONFIG, check_grads
def run_fwd_bwd(model, x, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
y = model(x)
loss = y.sum()
loss = loss.float()
loss.backward()
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
model = Net(checkpoint=True).cuda()
zero_model = copy.deepcopy(model)
zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
for _ in range(2):
x = torch.rand(2, 5).cuda()
run_fwd_bwd(zero_model, x, False)
run_fwd_bwd(model, x, False)
check_grads(model, zero_model)
@pytest.mark.dist
def test_shard_model_v2():
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_shard_model_v2()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from asyncio.log import logger
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.zero.shard_param import ShardParam
from colossalai.utils import free_port
from colossalai.logging import get_dist_logger, disable_existing_loggers
from tests.test_zero_data_parallel.common import Net, CONFIG
def run_shard_param_check(rank, world_size, port):
colossalai.launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
logger = get_dist_logger()
model = Net()
# add an attribute as ca_attr to hijack the access to param.data
for _, param in model.named_parameters():
numel_ref = (param.numel() + world_size - 1) // world_size
param.ca_attr = ShardParam(param)
param.ca_attr.shard()
param_data = param.ca_attr.payload(torch.device('cpu'))
logger.info(f'shard {param_data.shape} {param_data}', ranks = [1])
assert(numel_ref == param_data.numel())
for _, param in model.named_parameters():
param.ca_attr.gather()
param_data = param.ca_attr.payload(torch.device('cpu'))
logger.info(f'gather {param_data.shape} {param_data}', ranks = [1])
disable_existing_loggers([logger])
@pytest.mark.dist
def test_run_shard_shape():
world_size = 2
run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_run_shard_shape()
\ No newline at end of file
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, free_port
from colossalai.zero.sharded_model import ShardedModel
from common import Net, check_grads, check_params, check_params
def checkpoint_wrapper(module, enable=True):
if enable:
module.forward = partial(checkpoint, module.forward)
return module
class Net(nn.Module):
def __init__(self, checkpoint=False) -> None:
super().__init__()
self.fc1 = nn.Linear(5, 5)
self.fc2 = nn.Linear(5, 5)
self.fc3 = nn.Linear(5, 1)
if checkpoint:
self.fc1 = checkpoint_wrapper(self.fc1)
self.layers = [
self.fc1,
self.fc2,
self.fc1,
self.fc2,
self.fc3
]
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def run_step(model, optimizer, x, enable_autocast=False):
model.train()
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=enable_autocast):
y = model(x)
loss = y.sum()
loss = loss.float()
loss.backward()
optimizer.step()
def decode_booleans(intval, bits):
res = []
for bit in range(bits):
mask = 1 << bit
res.append((intval & mask) == mask)
return res
def check_config(checkpoint=False, fp16=False, offload=False):
model = Net(checkpoint=checkpoint).cuda()
zero_model = copy.deepcopy(model)
offload_config = {}
if offload:
offload_config['device'] = 'cpu'
zero_model = zero_model.cpu()
zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
for _ in range(5):
x = torch.rand(2, 5).cuda()
run_step(model, optimizer, x, enable_autocast=fp16)
run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
check_grads(model, zero_model)
check_params(model, zero_model)
for _ in range(5):
x = torch.rand(2, 5).cuda()
run_step(model, optimizer, x, enable_autocast=False)
run_step(zero_model, zero_optimizer, x, enable_autocast=False)
check_grads(model, zero_model, loose=True)
check_params(model, zero_model, loose=True)
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={},
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
args = ['checkpoint', 'fp16', 'offload']
def pack_args(i):
booleans = decode_booleans(i, len(args))
return {arg: booleans[idx] for idx, arg in enumerate(args)}
for j in range(2 ** len(args)):
kwargs = pack_args(j)
print(kwargs)
check_config(**kwargs)
@pytest.mark.dist
def test_zero_level_3():
world_size = 1
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_level_3()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, free_port
from colossalai.zero.sharded_model import ShardedModel
from torch.nn.parallel import DistributedDataParallel as DDP
from common import Net, allclose
def run_step(model, optimizer, x, enable_autocast=False):
model.train()
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=enable_autocast):
y = model(x)
loss = y.sum()
loss = loss.float()
loss.backward()
optimizer.step()
def check_grads_padding(model, zero_model, loose=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_grad = zero_p.grad.clone().to(p.device)
chunks = torch.flatten(p.grad).chunk(4)
if rank >= len(chunks):
continue
grad = chunks[rank]
if zero_p.zero_shard_padding > 0:
zero_grad = zero_grad[:-zero_p.zero_shard_padding]
assert grad.dtype == zero_grad.dtype
assert allclose(grad, zero_grad, loose=loose)
def check_params_padding(model, zero_model, loose=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_shard_padding = zero_p.zero_shard_padding
zero_p = zero_p.clone().to(p.device)
chunks = torch.flatten(p).chunk(4)
if rank >= len(chunks):
continue
p = chunks[rank]
if zero_shard_padding > 0:
zero_p = zero_p[:-zero_shard_padding]
assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose)
def decode_booleans(intval, bits):
res = []
for bit in range(bits):
mask = 1 << bit
res.append((intval & mask) == mask)
return res
def check_config(checkpoint=False, fp16=False, offload=False):
model = Net(checkpoint=checkpoint).cuda()
zero_model = copy.deepcopy(model)
ddp_model = DDP(model)
offload_config = {}
if offload:
offload_config['device'] = 'cpu'
zero_model = zero_model.cpu()
zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
for _ in range(5):
x = torch.rand(2, 5).cuda()
run_step(ddp_model, optimizer, x, enable_autocast=fp16)
run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
check_grads_padding(ddp_model, zero_model)
check_params_padding(ddp_model, zero_model)
for _ in range(5):
x = torch.rand(2, 5).cuda()
run_step(ddp_model, optimizer, x, enable_autocast=False)
run_step(zero_model, zero_optimizer, x, enable_autocast=False)
check_grads_padding(ddp_model, zero_model, loose=True)
check_params_padding(ddp_model, zero_model, loose=True)
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={},
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
args = ['checkpoint', 'fp16', 'offload']
def pack_args(i):
booleans = decode_booleans(i, len(args))
return {arg: booleans[idx] for idx, arg in enumerate(args)}
for j in range(2 ** len(args)):
kwargs = pack_args(j)
if dist.get_rank() == 0:
print(kwargs)
check_config(**kwargs)
@pytest.mark.dist
def test_zero_level_3():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_level_3()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_dataloader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
BATCH_SIZE = 16
IMG_SIZE = 224
CONFIG = dict(
fp16=dict(
mode=None,
),
zero=dict(
level=2,
cpu_offload=True,
verbose=False,
),
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None)
)
)
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
# build model
model = resnet18(num_classes=10)
# build dataloader# build dataloaders
train_dataset = CIFAR10(
root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose(
[
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
)
)
train_dataloader = get_dataloader(dataset=train_dataset,
shuffle=True,
batch_size=BATCH_SIZE,
pin_memory=True,
drop_last=True)
# build optimizer and loss
# optimizer = build_optimizer(global_context.config.optimizer, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
engine, train_dataloader, *args = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader)
# train
model.train()
for idx, (data, label) in enumerate(train_dataloader):
engine.zero_grad()
data = data.cuda()
label = label.cuda()
output = engine(data)
loss = engine.criterion(output, label)
engine.backward(loss)
engine.step()
break
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
def test_zero_level_2():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_level_2()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_dataloader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
BATCH_SIZE = 16
IMG_SIZE = 224
CONFIG = dict(
fp16=dict(
mode=None,
),
zero=dict(
level=3,
verbose=False,
offload_optimizer_config=dict(
device='cpu',
pin_memory=True,
buffer_count=5,
fast_init=False
),
offload_param_config=dict(
device='cpu',
pin_memory=True,
buffer_count=5,
buffer_size=1e8,
max_in_cpu=1e9
)
),
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None)
)
)
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
# build model
model = resnet18(num_classes=10)
# build dataloader# build dataloaders
train_dataset = CIFAR10(
root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose(
[
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
)
)
train_dataloader = get_dataloader(dataset=train_dataset,
shuffle=True,
batch_size=BATCH_SIZE,
pin_memory=True,
drop_last=True)
# build optimizer and loss
# optimizer = build_optimizer(global_context.config.optimizer, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
engine, train_dataloader, *args = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader)
# train
model.train()
for idx, (data, label) in enumerate(train_dataloader):
engine.zero_grad()
data = data.cuda()
label = label.cuda()
output = engine(data)
loss = engine.criterion(output, label)
engine.backward(loss)
engine.step()
break
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
def test_zero_level_3():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_level_3()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.zero.sharded_model.param_manager import Zero3ParameterManager
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from colossalai.utils import free_port
from common import CONFIG
def run_shard_shape_check(rank, world_size, port):
colossalai.launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
model = torch.nn.Linear(2, 4 * world_size)
gpc.init_parallel_groups()
Zero3ParameterManager(module=model, process_group=gpc.get_group(ParallelMode.DATA), offload_config=CONFIG.get('offload_param_config'))
assert(model.weight.numel() == 4 * 2)
assert(model.bias.numel() == 4)
@pytest.mark.dist
def test_run_shard_shape():
world_size = 2
run_func = partial(run_shard_shape_check, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_run_shard_shape()
......@@ -88,6 +88,7 @@ def run_2d_parallel_vision_transformer_level_2(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.skip(reason="This test should be refactored for the reconstructed zero")
def test_2d_vit_zero_level_2():
world_size = 8
run_func = partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size, port=free_port())
......
......@@ -88,7 +88,7 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
@pytest.mark.skip(reason="This test should be refactored for the reconstructed zero")
def test_3d_vit_zero_level_3():
world_size = 8
run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size, port=free_port())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment