Unverified commit 5ffb22d0 authored by Thor Johnsen, committed by GitHub

Merge pull request #1401 from timmoon10/dist-adam-zero

ZeRO-2 support in DistributedFusedAdam
parents 265b451d 846f7f8a
import collections
import contextlib
import enum
import importlib
import inspect
import math
import threading
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
from torch.distributed.distributed_c10d import _get_default_group
import torch.distributed.distributed_c10d as c10d
class DistributedFusedAdam(torch.optim.Optimizer):
"""AdamW optimizer with ZeRO algorithm.
Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
This implements the ZeRO-2 algorithm, which distributes the
optimizer state and gradients between parallel processes. In
particular, the parameters are flattened, grouped into fixed-size
buckets, and the optimizer state for each bucket is sharded over
the parallel processes. Options are provided to overlap the
gradient synchronization with the backward pass compute.
Adam was proposed in `Adam: A Method for Stochastic
Optimization`_, AdamW in `Decoupled Weight Decay Regularization`_,
and ZeRO in `ZeRO: Memory Optimizations Toward Training Trillion
Parameter Models`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts
defining parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for
computing running averages of gradient and its square.
(default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty)
(default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad
variant of this algorithm from the paper
`On the Convergence of Adam and Beyond`_ (default: False).
This is not yet supported.
dtype (torch.dtype, optional): datatype for optimizer state
(default: torch.float32)
grad_sync_dtype (torch.dtype, optional): datatype for gradient
synchronization (default: same as dtype)
param_sync_dtype (torch.dtype, optional): datatype for
parameter synchronization (default: same as
grad_sync_dtype)
device (torch.device, optional): device for optimizer state
(default: cuda). Currently only supports GPU.
process_group (torch.distributed.ProcessGroup, optional):
parallel processes participating in optimizer (default:
default group in torch.distributed). This group is
interpreted as a 2D grid with dimensions
distributed_size x redundant_size.
distributed_process_group (torch.distributed.ProcessGroup,
optional): parallel processes to distribute optimizer
state over (default: same as process_group)
redundant_process_group (torch.distributed.ProcessGroup,
optional): parallel processes to replicate optimizer state
over (default: group only containing calling process)
model_parallel (bool, optional): whether model parallelism is
used (default: False)
model_parallel_rank (int, optional): rank in model-parallel
process group (default: 0)
average_grad_sync (bool, optional): whether to use average
reduction for gradient synchronization rather than sum
(default: True)
overlap_grad_sync (boolean, optional): whether to overlap
gradient synchronization with backward pass compute
(default: True)
bucket_cap_mb (float, optional): bucket size in megabytes
(default: 15)
pipeline_size (int, optional): number of buckets to
synchronize simultaneously (default: 2)
fused_grad_copy (bool, optional): whether to use a fused kernel
to fill buckets with gradients (default: False). Requires
all parameters to have the same data type.
max_grad_norm (float, optional): maximum L2 norm for gradient
clipping (default: disabled)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
.. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
.. _ZeRO\: Memory Optimizations Toward Training Trillion Parameter Models:
https://arxiv.org/abs/1910.02054
"""
class GradientStatus(enum.Enum):
"""Status of gradients within a bucket"""
# Gradients are ready to use
READY = enum.auto()
# Bucket is partially filled with unreduced gradients
PARTIALLY_FILLED = enum.auto()
# Bucket is fully filled with unreduced gradients
FULLY_FILLED = enum.auto()
# Asynchronous reduction is in progress
SYNCING = enum.auto()
def __init__(self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.,
amsgrad=False,
dtype=torch.float32,
grad_sync_dtype=None,
param_sync_dtype=None,
device='cuda',
process_group=None,
distributed_process_group=None,
redundant_process_group=None,
model_parallel=False,
model_parallel_rank=0,
average_grad_sync=True,
overlap_grad_sync=True,
bucket_cap_mb=15,
pipeline_size=2,
fused_grad_copy=False,
max_grad_norm=0.,
):
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay)
super(DistributedFusedAdam, self).__init__(params, defaults)
# Adam options
if amsgrad:
raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')
# Datatype options
if grad_sync_dtype is None:
grad_sync_dtype = dtype
if param_sync_dtype is None:
param_sync_dtype = grad_sync_dtype
valid_dtypes = [
(torch.float32, torch.float16, torch.float16),
(torch.float32, torch.float32, torch.float32),
]
if (dtype, grad_sync_dtype, param_sync_dtype) not in valid_dtypes:
raise RuntimeError(
'Invalid dtypes for DistributedFusedAdam '
f'(dtype={dtype}, '
f'grad_sync_dtype={grad_sync_dtype}, '
f'param_sync_dtype={param_sync_dtype})')
if device != 'cuda':
raise RuntimeError('DistributedFusedAdam only supports GPU')
self.dtype = dtype
self.grad_sync_dtype = grad_sync_dtype
self.param_sync_dtype = param_sync_dtype
self.device = device
# Process groups
self.world_process_group = (
_get_default_group()
if process_group is None
else process_group
)
self.distributed_process_group = (
self.world_process_group
if distributed_process_group is None
else distributed_process_group
)
self.redundant_process_group = redundant_process_group
self.world_size = torch.distributed.get_world_size(self.world_process_group)
self.distributed_rank = torch.distributed.get_rank(self.distributed_process_group)
self.distributed_size = torch.distributed.get_world_size(self.distributed_process_group)
self.redundant_size = (
1
if self.redundant_process_group is None
else torch.distributed.get_world_size(self.redundant_process_group)
)
if (self.world_size != self.distributed_size * self.redundant_size):
raise RuntimeError(
'Invalid process group configuration '
f'(world process group size = {self.world_size}, '
f'distributed process group size = {self.distributed_size}, '
f'redundant process group size = {self.redundant_size})'
)
self.model_parallel = model_parallel
self.model_parallel_rank = model_parallel_rank
# Grad sync options
if fused_grad_copy:
_params = list(self.parameters())
if _params and (
any(p.dtype != self.grad_sync_dtype for p in _params)
or any(p.device != self.device for p in _params)):
raise RuntimeError(
'Attempted to use fused gradient copy in DistributedFusedAdam, '
'but parameters do not all have expected '
f'dtype ({self.grad_sync_dtype}) and device ({self.device})'
)
self.average_grad_sync = average_grad_sync
self.overlap_grad_sync = overlap_grad_sync
self.pipeline_size = pipeline_size
self.fused_grad_copy = fused_grad_copy
# Grad clipping options
self.max_grad_norm = max_grad_norm
# Determine bucket sizes
dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8
self.alignment = 128 // dtype_size
bucket_size = 1024*1024*bucket_cap_mb / dtype_size
shard_size = bucket_size / self.distributed_size
shard_size = (int(shard_size) // self.alignment) * self.alignment
shard_size = max(shard_size, self.alignment)
bucket_size = shard_size * self.distributed_size
self.bucket_size = bucket_size
self.shard_size = shard_size
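# Sizing example (illustrative, not part of the original code): with the
# default bucket_cap_mb=15, fp16 gradient synchronization (dtype_size=2
# bytes, alignment=64 elements), and a distributed process group of 8 ranks,
# bucket_size starts at 1024*1024*15/2 = 7864320 elements, shard_size rounds
# down to 983040 elements (already a multiple of 64), and bucket_size is then
# 8*983040 = 7864320 elements.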
# Load CUDA kernels
global fused_adam_cuda, distributed_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
distributed_adam_cuda = importlib.import_module("distributed_adam_cuda")
self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
# Optimizer state
self.state['buckets'] = []
self.state['step'] = 0
# Objects for gradient synchronization
self._grads_generated = set()
self._grads_to_copy = []
self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)]
# Check if collectives have no_copy option
self._reduce_scatter_no_copy = (
'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args
)
self._all_gather_no_copy = (
'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args
)
# Attach hooks for gradient synchronization
self._register_post_backward_hooks()
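# Minimal usage sketch (illustrative only; the process-group setup, model,
# and data loader below are assumptions, not part of this file):
#
#     torch.distributed.init_process_group(backend='nccl')
#     model = torch.nn.Linear(1024, 1024).cuda()
#     optimizer = DistributedFusedAdam(model.parameters(), lr=1e-3)
#     for batch in data_loader:
#         optimizer.zero_grad()
#         loss = model(batch).sum()
#         loss.backward()   # overlapped bucket gradient sync may start here
#         optimizer.step()  # finish sync, sharded Adam update, param all-gather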
def _register_post_backward_hooks(self):
"""Attach hooks for gradient synchronization
Optimizer state for parameters is initialized lazily as they
are encountered in the backward pass.
"""
self._num_grads = 0
self._lock = threading.Lock()
self._overlap_reductions = overlap_reductions
self._global_scale = None
self._num_blocks = dwu_num_blocks
self._num_chunks = dwu_num_chunks
self._predivide = predivide
self._e5m2_allgather = e5m2_allgather
self._do_not_flatten_model = do_not_flatten_model
self._compute_L2_grad_norm = compute_L2_grad_norm
self._L2_grad_norm = None
self._flat_mt = flat_mt
self._init_done = False
self._resume_from_checkpoint = False
self._step = 0
# Process group related
self._clip_grad_norm = clip_grad_norm
self._model_parallel = model_parallel
self._num_process_groups = num_process_groups
self._current_process_group = current_process_group if current_process_group is not None else c10d._get_default_group()
self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
self._process_group_id = process_group_id
self._process_group_size = torch.cuda.device_count() if process_group_size <= 0 else process_group_size
self._world_size = self._process_group_size # world: the current process group
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._num_groups = self._world_size // self._group_size
self._global_rank = torch.distributed.get_rank()
self._world_rank = self._global_rank // self._num_process_groups
self._group_rank = self._world_rank % self._group_size
#print("world_size:", self._world_size, ", group_size:", self._group_size, ", num_groups:", self._num_groups, ", global_rank:", self._global_rank, ", world_rank:", self._world_rank, ", group_rank:", self._group_rank)
self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg
# Master weight, moment, gradient buffers
self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None
def _first_step_init(self):
p_offset = 0
p_i = 0
self._model_params = []
self._grads_info = []
self._grad_accs = []
for param_group_id, group in enumerate(self.param_groups):
for param_id, param in enumerate(group['params']):
torch.distributed.broadcast(
param,
src=0,
group=self.world_process_group,
)
if param.requires_grad:
def wrapper(p, p_group_id, p_id):
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]
def reduction_hook(*unused):
with self._lock:
if 'fragments' not in self.state[p]:
self._init_param_state(p, p_group_id, p_id)
if self.overlap_grad_sync:
self._start_grad_copy(p)
self._try_start_bucket_grad_sync()
grad_acc.register_hook(reduction_hook)
self._grad_accs.append(grad_acc)
wrapper(param, param_group_id, param_id)
self._num_grads += 1
def _init_param_state(
self,
param,
param_group_id,
param_id,
):
"""Initialize optimizer state for a parameter"""
# Make sure there is at least one bucket
if not self.state['buckets']:
self._add_bucket()
# Split parameter values into fragments
# Note: Each fragment resides within a bucket
param_start = 0
param_size = param.numel()
self.state[param]['fragments'] = []
while param_start < param_size:
# Get current bucket
if not self.state['buckets']:
self._add_bucket()
bucket_id = len(self.state['buckets']) - 1
bucket = self.state['buckets'][bucket_id]
fragment_id = len(bucket['fragments'])
# Determine fragment position within bucket
if fragment_id == 0:
bucket_start = 0
else:
bucket_start = bucket['fragments'][-1]['bucket_range'][1]
bucket_start = (
(bucket_start + self.alignment - 1)
// self.alignment
* self.alignment
) # Pad until fragment is aligned
fragment_size = min(param_size-param_start, self.bucket_size-bucket_start)
param_end = param_start + fragment_size
bucket_end = bucket_start + fragment_size
# Create new bucket if current one is full
if fragment_size <= 0:
self._add_bucket()
continue
# Fragment position within local shard
shard_id = self.distributed_rank
shard_start = bucket_start - self.shard_size*shard_id
shard_end = bucket_end - self.shard_size*shard_id
shard_start = min(max(shard_start, 0), self.shard_size)
shard_end = min(max(shard_end, 0), self.shard_size)
in_local_shard = shard_start < shard_end
if in_local_shard:
shard_bucket_start = shard_start + self.shard_size*shard_id
shard_bucket_end = shard_bucket_start + shard_end - shard_start
shard_param_start = shard_bucket_start - bucket_start + param_start
shard_param_end = shard_param_start + shard_end - shard_start
else:
shard_bucket_start, shard_bucket_end = None, None
shard_param_start, shard_param_end = None, None
# Record fragment info
fragment = {
# Parameter group index
'param_group_id': param_group_id,
# Parameter index within parameter group
'param_id': param_id,
# Bucket index
'bucket_id': bucket_id,
# Range within flattened parameter buffer
'param_range': (param_start,param_end),
# Range within bucket
'bucket_range': (bucket_start,bucket_end),
# Whether fragment is in local shard of bucket
'in_local_shard': in_local_shard,
# Range within local shard
'shard_range': (shard_start,shard_end),
# Range of local fragment shard within bucket
'shard_bucket_range': (shard_bucket_start,shard_bucket_end),
# Range of local fragment shard within parameter
'shard_param_range': (shard_param_start,shard_param_end),
}
# Record fragment info
self.state[param]['fragments'].append(fragment)
bucket['fragments'].append(fragment)
param_start = param_end
# Initialize master param buffer
for fragment in self.state[param]['fragments']:
if fragment['in_local_shard']:
bucket_id = fragment['bucket_id']
bucket = self.state['buckets'][bucket_id]
param_start, param_end = fragment['shard_param_range']
shard_start, shard_end = fragment['shard_range']
model_param_fragment = param.view(-1)[param_start:param_end]
master_param_fragment = bucket['params_shard'][shard_start:shard_end]
master_param_fragment.copy_(model_param_fragment)
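# Fragment bookkeeping example (illustrative numbers, not the defaults): with
# bucket_size=1024 and shard_size=256 over 4 distributed ranks, a 600-element
# parameter whose first element lands at bucket offset 512 is recorded as two
# fragments: a 512-element fragment covering bucket range (512, 1024), whose
# shards belong to ranks 2 and 3, and an 88-element fragment at the start of
# the next bucket, which falls in rank 0's shard.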
def _add_bucket(self):
"""Construct a bucket for optimizer state"""
self.state['buckets'].append({
# Parameter fragments associated with bucket
'fragments': [],
# Gradient buffers
'grads_shard': None,
'grads_bucket': None,
'curr_grads_shard': None, # For current micro-batch
# Optimizer state
'params_shard': torch.zeros([self.shard_size], dtype=self.dtype, device=self.device),
'exp_avg_shard': torch.zeros([self.shard_size], dtype=self.dtype, device=self.device),
'exp_avg_sq_shard': torch.zeros([self.shard_size], dtype=self.dtype, device=self.device),
# Status of parameter gradients
'gradient_status': self.GradientStatus.READY,
# Distributed request object for gradient synchronization
'grad_sync_request': None,
})
def zero_grad(self, set_to_none=True):
"""Clear parameter gradients"""
for group in self.param_groups:
for param in group['params']:
if param.grad is None or set_to_none:
param.grad = None
else:
param.grad.zero_()
for bucket in self.state['buckets']:
bucket['grads_shard'] = None
bucket['grads_bucket'] = None
bucket['curr_grads_shard'] = None
bucket['gradient_status'] = self.GradientStatus.READY
self._grads_generated = set()
self._model_params.append(p)
def _start_grad_copy(self, param):
"""Copy parameter gradient to corresponding buckets
The copy is deferred if using a fused copy kernel.
beta2,
bias_correction,
eps,
weight_decay
))
p_grads_size = p.numel()
def wrapper(param, param_i, param_grads_size, param_offset):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
wrapper(p, p_i, p_grads_size, p_offset)
p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
p_i += 1
self._grads_generated = [False]*len(self._grads_info)
self._grads = []
if self._overlap_reductions:
self._current_block = self._num_blocks
self._net_total_param_size = p_offset
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._block_size = self._total_param_size // self._num_blocks
self._chunk_size = self._block_size // self._num_chunks
self._shard_size = self._chunk_size // self._group_size
#print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
#print(self._low_param_i)
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
# initialize master weights, moments buffers if not loaded from checkpoint
if self._fp32_p is None:
self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
# FIXME: Rethink fp16 label since it's either uint8 or fp16
self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
self._individual_flat_grads = []
for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))
def _flat_split(p):
def __blockify(p):
return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
def __shardify(p):
return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
list_of_blocks = __blockify(self._flat_grads)
list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
def _full_packed_split(p):
def __shardify(p):
return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
def __blockify(p):
return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
list_of_mega_shards = __shardify(p)
list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
def _packed_split(p):
def __packed_blockify(p):
packed_block_size = self._num_chunks*self._shard_size
return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
def __packed_chunkify(p):
# in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size
return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
list_of_blocks = __packed_blockify(p)
list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
# This paragraph does two things:
# 1) Copy model parameters into master buffer
# 2) Create tensor lists for unpacking new parameter tensor after all-gather
self._packed_flat_to_model_params = []
self._contrib_tensor_list = []
self._contrib_group_properties = []
self._non_parallel_grads = []
for shard_id in range(self._group_size):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size
flat_shard_end = flat_shard_start + self._shard_size
for (p, grads_info, group_props) in zip(self._model_params, self._grads_info, self._group_properties):
flat_grad_start = grads_info["param_offset"]
flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start)
clipped_end = (lambda a,b: a if a < b else b)(flat_grad_end, flat_shard_end)
if clipped_start < clipped_end:
grad_offset = clipped_start - flat_grad_start
grad_length = clipped_end - clipped_start
shard_offset = clipped_start - flat_shard_start
model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )
if shard_id == self._group_rank:
# copy model parameters into master buffer
master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
#print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
if not self._resume_from_checkpoint:
master_param_fragment.copy_(model_param_fragment)
self._contrib_group_properties.append(group_props)
self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, g, p_copy
if self._model_parallel and hasattr(p, 'model_parallel') and not p.model_parallel:
self._non_parallel_grads.append(opti_state_g_fragment)
p, m, v, g, p_copy = list(zip(*self._contrib_tensor_list))
self._contrib_tensor_list = [p, m, v, g, p_copy]
math_type = self._fp32_p.dtype
beta1, beta2, bias_correction, epsilon, decay = list(zip(*self._contrib_group_properties))
self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')
self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')
self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')
self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')
self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')
p_in, p_out = zip(*self._packed_flat_to_model_params)
self._packed_flat_to_model_params = [p_in, p_out]
if self._num_groups > 1:
self._ar_pg = []
for i in range(self._num_process_groups):
# gather global ranks of all members of the current process group
ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]
for j in range(self._group_size):
ar_idx = [j+k*self._group_size for k in range(self._num_groups)]
ar_rank = [ranks[k] for k in ar_idx]
#if self._global_rank in ar_rank:
# print("group for all reduce, ranks:", ar_rank)
for _ in range(self._num_ar_pg):
grp = torch.distributed.new_group(ranks=ar_rank)
if self._global_rank in ar_rank:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
for ar_pg in self._ar_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
self._rs_pg, rs_ranks = [],[]
for i in range(self._num_process_groups):
ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]
for j in range(self._num_groups):
rs_idx = [j*self._group_size+k for k in range(self._group_size)]
rs_rank = [ranks[k] for k in rs_idx]
#if self._global_rank in rs_rank:
# print("group for reduce scatter, ranks:", rs_rank)
for _ in range(self._num_rs_pg):
grp = torch.distributed.new_group(ranks=rs_rank)
if self._global_rank in rs_rank:
self._rs_pg.append(grp)
if self._compute_L2_grad_norm:
l2_grad_norm_pg = torch.distributed.new_group(ranks=rs_rank)
if self._global_rank in rs_rank:
self._l2_grad_norm_pg = l2_grad_norm_pg
torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
for rs_pg in self._rs_pg:
torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for i in range(self._num_process_groups):
ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]
for j in range(self._num_groups):
ag_rank = rs_ranks[j]
#if self._global_rank in ag_rank:
# print("group for all gather, ranks:", ag_rank)
for _ in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=ag_rank)
if self._global_rank in ag_rank:
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
for ag_pg in self._ag_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
self._completion_st = torch.cuda.Stream()
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
def _init_everything(self):
if not self._init_done:
self._first_step_init()
self._init_done = True
def set_last_step(self, last_step):
self._last_step = last_step
def _get_flush_block(self):
flush_block = []
if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
num_grads = len(self._grads_generated)
contiguous_idx = num_grads
while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
contiguous_idx -= 1
if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
self._current_block -= 1
start = self._current_block * self._block_size
end = (self._current_block+1) * self._block_size
flush_block = [start, end]
return flush_block
def _pipeline_block_reductions(self, block_id):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
# Optionally compute L2 grad norm
if self._compute_L2_grad_norm and block_id == 0:
with torch.cuda.stream(self._l2_grad_norm_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed
l2_grad_norm_sq = torch.empty([1], device='cuda')
l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
# for model_parallel_rank=0, keep all gradients
# for the rest, subtract non_parallel gradients
if self._model_parallel and self._process_group_id: # non zero model_parallel_rank
non_parallel_grad_norm_sq = torch.zeros([1], device='cuda')
if len(self._non_parallel_grads): # non parallel grads exit
non_parallel_grad_norm_sq = multi_tensor_applier(self.multi_tensor_l2norm,
self._overflow_buf,
[self._non_parallel_grads], False)[0]**2
torch.distributed.all_reduce(non_parallel_grad_norm_sq, group=self._l2_grad_norm_pg)
l2_grad_norm_sq = l2_grad_norm_sq - non_parallel_grad_norm_sq
self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
def __launch_step_kernel(self):
# If self._clip_grad_norm is False, we assume gradient clipping already
# happened outside the optimizer and self._global_scale has already
# been set to the combined scale, i.e. it's no longer the current loss
# scale used by the loss scaler.
# For model parallelism cases in which we need to get global gradient
# norm via all-reduce outside the optimizer to do the clipping.
combined_scale = self._global_scale
if self._clip_grad_norm and self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
self._step += 1
multi_tensor_applier(distributed_adam_cuda.multi_tensor_fused_adam,
self._overflow_buf,
self._contrib_tensor_list, # p, m, v, g, p_copy
self._contrib_beta1,
self._contrib_beta2,
self._contrib_bias_correction,
self._contrib_epsilon,
self._contrib_weight_decay,
self._param_group['lr'],
combined_scale,
self._step,
self.eps_mode)
def _pipeline_step(self):
# Call step kernel once per step
# Call all-gather once per step
with torch.cuda.stream(self._completion_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel()
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale):
if self._flat_mt and len(self._grads) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads)),
scale)
self._grads = []
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
# handle overlapped reductions
if self._flat_mt:
self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
else:
torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
self._grads_generated[param_i]=True
if not self._last_step:
if self._overlap_reductions:
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._pipeline_block_reductions(block_id)
flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
"""Set global scale.
"""
self._global_scale = global_scale
# Copy param grad to buckets
for fragment in self.state[param]['fragments']:
# Get fragment position
bucket_id = fragment['bucket_id']
bucket = self.state['buckets'][bucket_id]
grad_start, grad_end = fragment['param_range']
bucket_start, bucket_end = fragment['bucket_range']
# Set reduction status
if bucket['gradient_status'] == self.GradientStatus.SYNCING:
self._finish_bucket_grad_sync()
bucket['gradient_status'] = self.GradientStatus.PARTIALLY_FILLED
# Allocate gradient buffer if needed
if bucket['grads_bucket'] is None:
bucket['grads_bucket'] = torch.zeros(
[self.bucket_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
# Copy param grad to bucket
if param.grad is not None:
fragment_in = param.grad.view(-1)[grad_start:grad_end]
fragment_out = bucket['grads_bucket'][bucket_start:bucket_end]
self._grads_to_copy.append((fragment_in, fragment_out))
# Free param grad buffer
if not self.fused_grad_copy:
self._finish_grad_copy()
param.grad = None
# Update reduction statuses
self._grads_generated.add(param)
for fragment in self.state[param]['fragments']:
bucket_id = fragment['bucket_id']
bucket = self.state['buckets'][bucket_id]
is_filled = True
for other_fragment in reversed(bucket['fragments']):
param_group_id = other_fragment['param_group_id']
param_id = other_fragment['param_id']
other_param = self.param_groups[param_group_id]['params'][param_id]
if other_param not in self._grads_generated:
is_filled = False
break
if is_filled:
bucket['gradient_status'] = self.GradientStatus.FULLY_FILLED
def _finish_grad_copy(self):
"""Make sure that parameter gradients have been copied to buckets
Performs any deferred copies from _start_grad_copy.
@property
def has_overflow(self):
"""Check if overflows were detected by any call to step(...) method.
Clears the overflow flag.
"""
if self._grads_to_copy:
scale = 1/self.world_size if self.average_grad_sync else 1.0
if self.fused_grad_copy:
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
multi_tensor_applier(
amp_C.multi_tensor_scale,
dummy_overflow_buf,
list(zip(*self._grads_to_copy)),
scale,
)
else:
for fragment_in, fragment_out in self._grads_to_copy:
fragment_out.add_(fragment_in, alpha=scale)
self._grads_to_copy = []
def _force_bucket_grad_sync(self):
"""Ensure that all gradient buckets are synchronized"""
# Synchronize all unsynchronized buckets
self._finish_bucket_grad_sync()
self._start_bucket_grad_sync([
bucket for bucket in self.state['buckets']
if bucket['gradient_status'] != self.GradientStatus.READY
])
self._finish_bucket_grad_sync()
# Fill any unfilled buckets with zeros
for bucket in self.state['buckets']:
if bucket['grads_shard'] is None:
bucket['grads_shard'] = torch.zeros(
[self.shard_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
# Reset set of generated gradients
self._grads_generated = set()
def _try_start_bucket_grad_sync(self):
"""Launches gradient synchronization if enough buckets are ready
Gradient synchronization is asynchronous. Launches gradient
synchronization if all gradients have been generated or if
there are enough buckets ready to fill pipeline.
def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
"""Strided check for overflow.
You can get status by calling has_overflow.
"""
if len(self._grads_generated) == self._num_grads:
self._force_bucket_grad_sync()
else:
out_p = output_params
fused_adam_cuda.strided_check_finite(self._overflow_buf,
out_p,
stride,
1 if clear else 0)
self._has_overflow = False if self._overflow_buf.item() == 0 else True
return self._has_overflow
@property
def L2_grad_norm(self):
if self._compute_L2_grad_norm:
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
else:
filled_buckets = [
bucket
for bucket in self.state['buckets'][:-1]
if bucket['gradient_status'] == self.GradientStatus.FULLY_FILLED
]
pipeline_size = (len(filled_buckets) // self.pipeline_size) * self.pipeline_size
if pipeline_size > 0:
self._start_bucket_grad_sync(filled_buckets[:pipeline_size])
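# Example (illustrative): with the default pipeline_size=2, synchronization
# is launched once 2, 4, ... buckets (excluding the last, possibly
# still-filling bucket) are fully filled, so reductions are issued in
# pipeline-sized batches across the pipeline streams during the backward pass.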
def _start_bucket_grad_sync(self, buckets):
"""Synchronize gradients in buckets
Gradient synchronization is asynchronous. Involves
reduce-scatter over distributed process group and allreduce
over redundant process group.
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
self._finish_bucket_grad_sync()
self._finish_grad_copy()
# Reduce gradients
for stream in self._pipeline_streams:
stream.wait_stream(torch.cuda.current_stream())
for i, bucket in enumerate(buckets):
bucket['gradient_status'] = self.GradientStatus.SYNCING
stream = self._pipeline_streams[i % self.pipeline_size]
with torch.cuda.stream(stream):
# Reduce-scatter over distributed process group
if self.distributed_size == 1:
bucket['curr_grads_shard'] = bucket['grads_bucket']
bucket['grad_sync_request'] = None
else:
bucket['curr_grads_shard'] = torch.zeros(
[self.shard_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
grads_bucket_shards = [
bucket['grads_bucket'][i*self.shard_size:(i+1)*self.shard_size]
for i in range(self.distributed_size)
]
if self._reduce_scatter_no_copy:
no_copy_kwarg = { 'no_copy': True }
else:
no_copy_kwarg = {}
bucket['grad_sync_request'] = (
torch.distributed.reduce_scatter(
bucket['curr_grads_shard'],
grads_bucket_shards,
group=self.distributed_process_group,
async_op=True,
**no_copy_kwarg,
)
)
# All-reduce over redundant process group
# Note: Assuming reduce-scatters are finished in the
# order they are submitted, all-reduces should be
# submitted in a consistent order. There could be race
# conditions if wait doesn't finish in order.
if self.redundant_size > 1:
if bucket['grad_sync_request'] is not None:
bucket['grad_sync_request'].wait()
bucket['grad_sync_request'] = (
torch.distributed.all_reduce(
bucket['curr_grads_shard'],
group=self.redundant_process_group,
async_op=True,
)
)
def _finish_bucket_grad_sync(self):
"""Wait for any gradient synchronizations that are in progress"""
for bucket in self.state['buckets']:
if bucket['gradient_status'] == self.GradientStatus.SYNCING:
# Finish asynchronous communication
if bucket['grad_sync_request'] is not None:
bucket['grad_sync_request'].wait()
bucket['grad_sync_request'] = None
# Accumulate gradient in local shard
if bucket['grads_shard'] is None:
bucket['grads_shard'] = bucket['curr_grads_shard']
else:
bucket['grads_shard'].add_(bucket['curr_grads_shard'])
# Deallocate buffers for gradient synchronization
bucket['grads_bucket'] = None
bucket['curr_grads_shard'] = None
# Reset status
bucket['gradient_status'] = self.GradientStatus.READY
@contextlib.contextmanager
def no_sync(self):
"""Disable overlapped gradient synchronization
Context manager that is similar to
torch.nn.parallel.DistributedDataParallel.no_sync. The
gradients can be synchronized by calling grad_sync or step. If
overlapped gradient synchronization is enabled, gradients can
also be synchronized by leaving the context and performing a
backward pass.
"""
old_overlap_grad_sync = self.overlap_grad_sync
self.overlap_grad_sync = False
try:
yield
finally:
self.overlap_grad_sync = old_overlap_grad_sync
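# Gradient-accumulation sketch (illustrative; `model` and `microbatches` are
# assumptions). Synchronization is skipped for all but the last microbatch:
#
#     for i, microbatch in enumerate(microbatches):
#         if i < len(microbatches) - 1:
#             with optimizer.no_sync():
#                 model(microbatch).sum().backward()
#         else:
#             model(microbatch).sum().backward()
#     optimizer.step()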
def grad_sync(self):
"""Ensure that all gradients are synchronized"""
for bucket in self.state['buckets']:
for fragment in bucket['fragments']:
param_group_id = fragment['param_group_id']
param_id = fragment['param_id']
param = self.param_groups[param_group_id]['params'][param_id]
if param.grad is not None:
self._start_grad_copy(param)
self._try_start_bucket_grad_sync()
self._force_bucket_grad_sync()
def grad_norm(self):
"""Compute L2 norm of all parameter gradients
If model parallelism is enabled, exclude non-parallel
gradients on non-root processes. This is Megatron-specific, so
should this logic be moved elsewhere?
"""
# Make sure that gradients have been reduced
self.grad_sync()
# Evaluate L2 norm of distributed gradients
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
grad_norm_sq = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[[bucket['grads_shard'] for bucket in self.state['buckets']]],
False,
)[0] ** 2
torch.distributed.all_reduce(
grad_norm_sq,
group=self.distributed_process_group,
)
# If model parallelism is enabled, subtract non-parallel
# gradients on non-root processes
if self.model_parallel and self.model_parallel_rank:
non_parallel_grads = []
for bucket in self.state['buckets']:
for fragment in bucket['fragments']:
if fragment['in_local_shard']:
param_group_id = fragment['param_group_id']
param_id = fragment['param_id']
param = self.param_groups[param_group_id]['params'][param_id]
if (hasattr(param, 'model_parallel')
and not param.model_parallel):
shard_start, shard_end = fragment['shard_range']
non_parallel_grads.append(
bucket['grads_shard'][shard_start:shard_end]
)
if non_parallel_grads:
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
non_parallel_grad_norm_sq = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[non_parallel_grads],
False,
)[0] ** 2
else:
non_parallel_grad_norm_sq = torch.zeros([1], device=self.device)
torch.distributed.all_reduce(
non_parallel_grad_norm_sq,
group=self.distributed_process_group,
)
grad_norm_sq -= non_parallel_grad_norm_sq
return grad_norm_sq.sqrt()
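# Note (descriptive, not in the original): step() uses this norm when
# max_grad_norm > 0, rescaling the fused update via scale = grad_norm /
# max_grad_norm instead of modifying the gradient buffers in place.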
def step(self, closure=None, scale=1.):
"""Apply Adam optimizer step
Arguments:
closure (callable, optional): closure to recompute loss
(default: None)
scale (float, optional): scaling factor to divide
gradients (default: 1.0)
"""
self.state['step'] += 1
loss = None
if closure is not None:
loss = closure()
# Make sure that gradients have been reduced
self.grad_sync()
# Scale gradient if L2 norm is too large
if self.max_grad_norm > 0:
grad_norm = self.grad_norm().item()
if (math.isfinite(grad_norm)
and grad_norm / scale > self.max_grad_norm):
scale = grad_norm / self.max_grad_norm
# Apply optimizer step to each bucket and synchronize params
current_stream = torch.cuda.current_stream()
for stream in self._pipeline_streams:
stream.wait_stream(current_stream)
for i, bucket in enumerate(self.state['buckets']):
stream = self._pipeline_streams[i % self.pipeline_size]
with torch.cuda.stream(stream):
# Buffer for param sync
params_shard_copy = torch.zeros(
[self.shard_size],
dtype=self.param_sync_dtype,
device=self.device,
)
# Find param fragments in local shard
buffers = collections.defaultdict(list) # p, m, v, g, p_copy
for fragment in bucket['fragments']:
if fragment['in_local_shard']:
param_group_id = fragment['param_group_id']
shard_start, shard_end = fragment['shard_range']
buffers[param_group_id].append([
bucket['params_shard'][shard_start:shard_end],
bucket['exp_avg_shard'][shard_start:shard_end],
bucket['exp_avg_sq_shard'][shard_start:shard_end],
bucket['grads_shard'][shard_start:shard_end],
params_shard_copy[shard_start:shard_end],
])
# Fuse param fragments if possible
if len(buffers) == 1:
group_id = list(buffers.keys())[0]
buffers[group_id] = [(
bucket['params_shard'],
bucket['exp_avg_shard'],
bucket['exp_avg_sq_shard'],
bucket['grads_shard'],
params_shard_copy,
)]
# Apply optimizer step to each param group
for group_id, group_buffers in buffers.items():
# Get param group configs
group = self.param_groups[group_id]
beta1, beta2 = group['betas']
bias_correction = 1 if group['bias_correction'] else 0
eps = group['eps']
weight_decay = group['weight_decay']
# Copy param group configs to GPU
num_fragments = len(group_buffers)
beta1 = torch.full([num_fragments], beta1, dtype=self.dtype, device='cuda')
beta2 = torch.full([num_fragments], beta2, dtype=self.dtype, device='cuda')
bias_correction = torch.full([num_fragments], bias_correction, dtype=torch.int32, device='cuda')
eps = torch.full([num_fragments], eps, dtype=self.dtype, device='cuda')
weight_decay = torch.full([num_fragments], weight_decay, dtype=self.dtype, device='cuda')
# Apply Adam step
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
multi_tensor_applier(
distributed_adam_cuda.multi_tensor_fused_adam,
dummy_overflow_buf,
list(zip(*group_buffers)),
beta1,
beta2,
bias_correction,
eps,
weight_decay,
group['lr'],
scale,
self.state['step'],
1, # Set to 0 to apply eps inside sqrt
)
# Deallocate buffers
del buffers
bucket['grads_shard'] = None
# Allgather updated parameters
if self.distributed_size == 1:
params_bucket = params_shard_copy
else:
params_bucket = torch.zeros(
[self.bucket_size],
dtype=self.param_sync_dtype,
device=self.device,
)
params_bucket_shards = [
params_bucket[i*self.shard_size:(i+1)*self.shard_size]
for i in range(self.distributed_size)
]
params_bucket_shards[self.distributed_rank].copy_(params_shard_copy)
if self._all_gather_no_copy:
no_copy_kwarg = { 'no_copy': True }
else:
no_copy_kwarg = {}
torch.distributed.all_gather(
params_bucket_shards,
params_bucket_shards[self.distributed_rank],
group=self.distributed_process_group,
**no_copy_kwarg,
)
del params_shard_copy
# Copy values to param buffers
params_in = []
params_out = []
for fragment in bucket['fragments']:
param_group_id = fragment['param_group_id']
param_id = fragment['param_id']
param = self.param_groups[param_group_id]['params'][param_id]
bucket_start, bucket_end = fragment['bucket_range']
param_start, param_end = fragment['param_range']
params_in.append(params_bucket[bucket_start:bucket_end])
params_out.append(param.view(-1)[param_start:param_end])
if params_in:
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
dummy_overflow_buf,
[params_in, params_out],
)
del params_bucket, params_in, params_out
# Synchronize pipeline streams
for stream in self._pipeline_streams:
current_stream.wait_stream(stream)
return loss
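# Loss-scaling sketch (illustrative; the static loss scale below is an
# assumption, not part of this file):
#
#     loss_scale = 128.0
#     (loss * loss_scale).backward()
#     optimizer.step(scale=loss_scale)  # gradients are divided by `scale`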
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`DistributedFusedAdam` instance.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
# save step, master weights and first/second moments
state_dict = {}
state_dict['step'] = self._step
state_dict['fp32_p'] = self._fp32_p
state_dict['fp32_m'] = self._fp32_m
state_dict['fp32_v'] = self._fp32_v
return state_dict
def load_state_dict(self, state_dict):
"""
Loads a state_dict created by an earlier call to state_dict().
If a DistributedFusedAdam instance was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``optimizer.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# restore step, master weights and first/second moments
self._step = state_dict['step']
self._fp32_p = state_dict['fp32_p'].to(device="cuda")
self._fp32_m = state_dict['fp32_m'].to(device="cuda")
self._fp32_v = state_dict['fp32_v'].to(device="cuda")
self._resume_from_checkpoint = True
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
class DistributedFusedAdamV2(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
overlap_reductions(boolean, optional): whether to overlap reductions
with bprop (default: True)
num_prestats (integer, optional): number of fp64 stats that will be
reduced during first fp16 gradient reduction block.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params,
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
compute_L2_grad_norm=False, distributed_weight_update=0,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
do_not_flatten_model=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._amp_scale_adjustment = amp_scale_adjustment
if use_mt:
raise RuntimeError('DistributedFusedAdam does not support use_mt.')
if amsgrad:
raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(DistributedFusedAdamV2, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
self._overflow_buf = torch.cuda.IntTensor([0])
self._has_overflow = False
assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
# Way to revert a step
# 3 -> undo kernel + double buffer (debug, print norm of difference)
# 2 -> double buffer fp32 parameters
# 1 -> undo kernel
self._revert_method = revert_method
if self._revert_method > 1:
print("revert_method -> double buffer fp32 parameters, will consume more memory")
self._last_step = False
self._overlap_reductions = overlap_reductions
self._global_scale = None
self._num_blocks = dwu_num_blocks
self._num_chunks = dwu_num_chunks
self._predivide = predivide
self._e5m2_allgather = e5m2_allgather
self._do_not_flatten_model = do_not_flatten_model
self._full_pipeline = full_pipeline
self._compute_L2_grad_norm = compute_L2_grad_norm
self._L2_grad_norm = None
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._world_size = torch.distributed.get_world_size()
self._num_groups = self._world_size // self._group_size
self._rank_in_group = torch.distributed.get_rank() % self._group_size
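        # Toy check of the process-group arithmetic above (standalone sketch, never
        # called; the numbers are hypothetical).
        def _toy_topology_sketch(world_size=16, group_size=8, rank=11):
            num_groups = world_size // group_size   # 2 groups of 8 GPUs each
            rank_in_group = rank % group_size       # global rank 11 -> position 3 in its group
            assert (num_groups, rank_in_group) == (2, 3)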
p_offset = 0
p_i = 0
self._param_state = None
self._model_params = []
self._grads_info = []
self._grad_accs = []
for group in self.param_groups:
self._param_group = group
prev = None
for p in group['params']:
torch.distributed.broadcast(p,0)
if not p.requires_grad:
continue
self._model_params.append(p)
state = self.state[p]
if len(state) == 0:
state['step'] = 0
if self._param_state is None:
self._param_state = state
p_grads_size = p.numel()
def wrapper(param, param_i, param_grads_size, param_offset):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
wrapper(p, p_i, p_grads_size, p_offset)
p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
p_i += 1
self._grads_generated = [False]*len(self._grads_info)
self._flat_mt = flat_mt
self._grads = []
if self._overlap_reductions:
self._current_block = self._num_blocks
self._net_total_param_size = p_offset
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._block_size = self._total_param_size // self._num_blocks
self._shard_size = self._block_size // self._group_size
self._chunk_size = self._shard_size // self._num_chunks
print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d, self._chunk_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size,self._chunk_size))
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
print(self._low_param_i)
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._mega_shard_size = self._num_blocks * self._num_chunks * self._chunk_size
self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
# FIXME: Rethink fp16 label since it's either uint8 or fp16
self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
self._individual_flat_grads = []
for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))
def _flat_split(p):
def __blockify(p):
return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __shardify(p):
return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
def __chunkify(p):
                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]  # each shard splits into num_chunks chunks
list_of_blocks = __blockify(self._flat_grads)
list_of_list_of_shards = [__shardify(block) for block in list_of_blocks]
list_of_list_of_list_of_chunks = [[__chunkify(shard) for shard in shards] for shards in list_of_list_of_shards]
return list_of_blocks, list_of_list_of_shards, list_of_list_of_list_of_chunks
self._flat_grads_blocks, self._flat_grads_shards, self._flat_grads_chunks = _flat_split(self._flat_grads)
def _full_packed_split(p):
def __shardify(p):
return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
def __blockify(p):
return [p[block_id*self._num_chunks*self._chunk_size:(block_id+1)*self._num_chunks*self._chunk_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
list_of_mega_shards = __shardify(p)
list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
def _packed_split(p):
def __packed_blockify(p):
packed_block_size = self._num_chunks*self._chunk_size
return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
def __packed_chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
list_of_blocks = __packed_blockify(p)
list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
# current arrangement
#
# self._flat_grads
# self._flat_grads_blocks [x self._num_blocks, self._block_size]
# self._flat_grads_chunks [x self._num_chunks, self._chunk_size]
# self._flat_grads_shards [x self._group_size, self._shard_size]
#
# self._new_params
# self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._shard_size]
# self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._shard_size]
# self._new_params_mega_chunks [x self._num_chunks, self._shard_size]
#
# self._fp32_p
# self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._shard_size]
# self._fp32_p_chunks [x self._num_chunks, self._shard_size]
# each chunk contains one shard
# same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g
#
# Usage:
#
# for chunk_id in range(self._num_chunks):
# works[chunk_id] = torch.distributed.reduce_scatter(self._flat_grads_chunks[block_id][chunk_id], self._fp16_g_chunks[block_id][chunk_id], ...)
#
# ----------------------------------------------------------------------------------------
#
# new arrangement
#
# NB! New equations for self._shard_size and self._chunk_size
#
# self._flat_grads
# self._flat_grads_blocks [x self._num_blocks, self._block_size]
# self._flat_grads_shards [x self._group_size, self._shard_size]
# self._flat_grads_chunks [x self._num_chunks, self._chunk_size]
#
# self._new_params
# self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._chunk_size]
# self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]
# self._new_params_mega_chunks [x self._num_chunks, self._chunk_size]
#
# self._fp32_p
# self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]
# self._fp32_p_chunks [x self._num_chunks, self._chunk_size]
# same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g
#
# Usage:
#
# work = torch.distributed.reduce_scatter(self._flat_grads_blocks[block_id], self._fp16_g[block_id], ...)
# for chunk_id in range(self._num_chunks):
# work.wait()
# works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], ...)
# or
# work.wait()
# works[0] = torch.distributed.all_reduce(self._fp16_g_blocks[block_id], ...)
#
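        # Standalone sketch (never called) of how a flat-buffer element maps to
        # (block, shard, chunk) offsets under the new arrangement; the toy sizes
        # below are hypothetical, not the real bucket geometry.
        def _toy_flat_layout_sketch():
            num_blocks, group_size, num_chunks, chunk_size = 2, 2, 2, 4
            shard_size = num_chunks * chunk_size    # 8
            block_size = group_size * shard_size    # 16
            flat = list(range(num_blocks * block_size))
            block_id, shard_id, chunk_id = 1, 0, 1
            start = block_id * block_size + shard_id * shard_size + chunk_id * chunk_size
            assert flat[start:start + chunk_size] == [20, 21, 22, 23]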
# This paragraph does two things:
# 1) Copy model parameters into master buffer
# 2) Create tensor lists for unpacking new parameter tensor after all-gather
self._packed_flat_to_model_params = []
for shard_id in range(self._group_size):
for block_id in range(self._num_blocks):
flat_shard_start = (block_id * self._group_size + shard_id) * self._shard_size
flat_shard_end = flat_shard_start + self._shard_size
for p, grads_info in zip(self._model_params, self._grads_info):
flat_grad_start = grads_info["param_offset"]
flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
                    clipped_start = max(flat_grad_start, flat_shard_start)
                    clipped_end = min(flat_grad_end, flat_shard_end)
if clipped_start < clipped_end:
grad_offset = clipped_start - flat_grad_start
grad_length = clipped_end - clipped_start
shard_offset = clipped_start - flat_shard_start
model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
new_param_packed_fragment = self._new_params_mega_blocks[shard_id][block_id][shard_offset:shard_offset+grad_length]
self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )
if shard_id == self._rank_in_group:
# copy model parameters into master buffer
master_param_fragment = self._fp32_p_blocks[block_id][shard_offset:shard_offset+grad_length]
print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
master_param_fragment.copy_(model_param_fragment)
p_in, p_out = zip(*self._packed_flat_to_model_params)
self._packed_flat_to_model_params = [p_in, p_out]
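        # Standalone sketch (never called) of the interval clipping used above: a
        # parameter fragment is the overlap between the parameter's flat-gradient
        # range and a shard's range. The numbers are hypothetical.
        def _toy_fragment_overlap_sketch():
            flat_grad_start, flat_grad_end = 10, 50     # parameter occupies [10, 50)
            flat_shard_start, flat_shard_end = 32, 64   # shard covers [32, 64)
            clipped_start = max(flat_grad_start, flat_shard_start)  # 32
            clipped_end = min(flat_grad_end, flat_shard_end)        # 50
            grad_offset = clipped_start - flat_grad_start           # 22 elements into the parameter
            shard_offset = clipped_start - flat_shard_start         # 0 elements into the shard
            grad_length = clipped_end - clipped_start               # 18 overlapping elements
            assert (grad_offset, shard_offset, grad_length) == (22, 0, 18)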
self._distributed_weight_update = distributed_weight_update # Is this still needed?
self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg
if self._num_groups > 1:
self._ar_pg = []
for dev_i in range(self._group_size):
ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
for i in range(self._num_ar_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
for ar_pg in self._ar_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
rs_ranks = []
for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
self._rs_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_rs_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
for rs_pg in self._rs_pg:
torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
for ag_pg in self._ag_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
self._completion_st = torch.cuda.Stream()
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
def set_last_step(self, last_step):
self._last_step = last_step
def _get_flush_block(self):
flush_block = []
if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
num_grads = len(self._grads_generated)
contiguous_idx = num_grads
while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
contiguous_idx -= 1
if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
self._current_block -= 1
start = self._current_block * self._block_size
end = (self._current_block+1) * self._block_size
flush_block = [start, end]
return flush_block
def _pipeline_block_reductions(self, block_id):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
rs_stream = self._rs_st[block_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(rs_stream):
rs_work = torch.distributed.reduce_scatter(self._fp16_g_blocks[block_id],self._flat_grads_shards[block_id],group=self._rs_pg[block_id%self._num_rs_pg],async_op=True,no_copy=True)
for chunk_id in range(self._num_chunks):
works[chunk_id] = rs_work
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
rs_work.wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
# Optionally compute L2 grad norm
if self._compute_L2_grad_norm and block_id == 0:
with torch.cuda.stream(self._l2_grad_norm_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed
                    l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
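    # Standalone sketch (not a method used by this class) of the two-level
    # reduction in _pipeline_block_reductions: a reduce-scatter within the node
    # leaves each rank with one shard, then an all-reduce across nodes combines
    # the matching shards. `shard_views`, `shard_out`, `intra_node_pg`, and
    # `inter_node_pg` are hypothetical names; the real code also relies on a
    # patched c10d `no_copy` option that is omitted here.
    def _two_level_reduction_sketch(shard_views, shard_out, intra_node_pg, inter_node_pg):
        rs_work = torch.distributed.reduce_scatter(
            shard_out, shard_views, group=intra_node_pg, async_op=True)
        rs_work.wait()
        ar_work = torch.distributed.all_reduce(
            shard_out, group=inter_node_pg, async_op=True)
        ar_work.wait()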
def __launch_step_kernel(self, p, p_copy, m, v, g):
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.reversible_adam(
p, p_copy, m, v, g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
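    # Worked example (standalone sketch, hypothetical numbers) of the combined_scale
    # computed in __launch_step_kernel: with global_scale=65536, an unscaled gradient
    # norm of 8 (so L2_grad_norm = 8*65536) and max_grad_norm=1, the clip factor is
    # ~1/8, so combined_scale ~= 65536/0.125 = 524288 and the kernel divides the
    # gradients by roughly 8x more than the loss scale alone.
    def _combined_scale_sketch(global_scale=65536.0, l2_grad_norm=8.0*65536.0, max_grad_norm=1.0):
        combined_scale = global_scale
        if max_grad_norm > 0 and math.isfinite(l2_grad_norm):
            combined_scale = max_grad_norm / (l2_grad_norm / global_scale + 1e-6)
            combined_scale = global_scale / min(1, combined_scale)
        return combined_scale  # ~524288 for the default arguments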
def _pipeline_block_step(self, block_id):
# Call step kernel once per block
ag_stream = self._ag_st[block_id%self._num_ag_pg]
with torch.cuda.stream(ag_stream):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel(
self._fp32_p_blocks[block_id],
self._fp16_p_blocks[block_id],
self._fp32_m_blocks[block_id],
self._fp32_v_blocks[block_id],
self._fp16_g_blocks[block_id])
# Call all-gather once per step.
# FIXME: Determine which is faster, one all-gather per block or a single all-gather at end
if block_id == 0:
for other_ag_stream in self._ag_st:
self._completion_st.wait_stream(other_ag_stream)
with torch.cuda.stream(self._completion_st):
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _pipeline_step(self):
# Call step kernel once per step
# Call all-gather once per step
with torch.cuda.stream(self._completion_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel(
self._fp32_p,
self._fp16_p,
self._fp32_m,
self._fp32_v,
self._fp16_g)
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale):
if self._flat_mt and len(self._grads) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads)),
scale)
self._grads = []
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
# handle overlapped reductions
if self._flat_mt:
self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
else:
torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
self._grads_generated[param_i]=True
if not self._last_step:
if self._overlap_reductions:
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._pipeline_block_reductions(block_id)
if self._full_pipeline:
self._pipeline_block_step(block_id)
flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
"""Set global scale.
"""
self._global_scale = global_scale
@property
def global_scale(self):
return self._global_scale
@property
def has_overflow(self):
"""Check if overflows were detected by any call to step(...) method.
Clears the overflow flag.
"""
has_overflow = self._has_overflow
self._has_overflow = False
return has_overflow
@property
def peek_overflow(self):
"""Check if overflows were detected by any call to step(...) method.
Does not clear overflow flag.
"""
return self._has_overflow
def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
"""Strided check for overflow.
You can get status by calling has_overflow.
"""
if start >= 0 and start < end:
out_p = output_params[start:end]
else:
out_p = output_params
fused_adam_cuda.strided_check_finite(self._overflow_buf,
out_p,
stride,
1 if clear else 0)
        self._has_overflow = self._overflow_buf.item() != 0
return self._has_overflow
@property
def L2_grad_norm(self):
if self._compute_L2_grad_norm:
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
else:
return None
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
if self._last_step:
# zero out gradients that have not been completed yet
for param_i, grad_generated in enumerate(self._grads_generated):
if not grad_generated:
grad_info = self._grads_info[param_i]
param_offset = grad_info["param_offset"]
param_size = grad_info["param_grads_size"]
self._flat_grads[param_offset:param_offset+param_size].zero_()
self._grads_generated[param_i] = True
if self._last_step or not self._overlap_reductions:
# nothing done so far, run full pipeline after reductions
for block_id in range(self._num_blocks-1,-1,-1):
self._pipeline_block_reductions(block_id)
if self._compute_L2_grad_norm:
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
self._current_block = self._num_blocks
self._grads_generated = [False]*len(self._grads_info)
def revert_step(self):
"""Revert effect of previously calling partial_step.
"""
# Call undo kernel once per step
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.maybe_adam_undo(
torch.empty([0]),
self._fp32_p,
self._fp32_m,
self._fp32_v,
self._fp16_g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def step(self, closure=None, skip_overflow_check=False):
loss = None
if closure is not None:
loss = closure()
if self._last_step or not self._overlap_reductions or not self._full_pipeline:
self._pipeline_step()
with torch.cuda.stream(self._completion_st):
# Check for overflow
# Store state for loss scaler calculation
has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
if has_overflow:
self.revert_step()
else:
# Copy self._new_params to model params
for p in self._model_params: self.state[p]['step'] += 1
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params)
torch.cuda.current_stream().wait_stream(self._completion_st)
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
return loss
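
# Rough usage sketch for DistributedFusedAdamV2 (standalone; `model`, `loss_fn`,
# `data`, and `loss_scale` are hypothetical). Backward hooks copy gradients into
# the optimizer's flat buffer, complete_reductions() finishes any reductions not
# overlapped with the backward pass, and step() runs the fused Adam kernel plus
# the parameter all-gather.
def _training_loop_sketch(model, loss_fn, data, optimizer, loss_scale=65536.0):
    optimizer.set_global_scale(loss_scale)
    for x, y in data:
        loss = loss_fn(model(x), y)
        (loss * loss_scale).backward()   # grads are accumulated into the flat buffer
        optimizer.complete_reductions()
        optimizer.step()
        for p in model.parameters():     # release grad references; the flat buffer is reused
            p.grad = None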
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
class DistributedFusedAdamV3(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
        overlap_reductions (boolean, optional): whether to overlap reductions
            with bprop (default: True)
num_prestats (integer, optional): number of fp64 stats that will be
reduced during first fp16 gradient reduction block.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params,
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
compute_L2_grad_norm=False, distributed_weight_update=0,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
do_not_flatten_model=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._amp_scale_adjustment = amp_scale_adjustment
if use_mt:
raise RuntimeError('DistributedFusedAdam does not support use_mt.')
if amsgrad:
raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(DistributedFusedAdamV3, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
self._overflow_buf = torch.cuda.IntTensor([0])
assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
# Way to revert a step
# 3 -> undo kernel + double buffer (debug, print norm of difference)
# 2 -> double buffer fp32 parameters
# 1 -> undo kernel
self._revert_method = revert_method
if self._revert_method > 1:
print("revert_method -> double buffer fp32 parameters, will consume more memory")
self._last_step = False
self._overlap_reductions = overlap_reductions
self._global_scale = None
self._num_blocks = dwu_num_blocks
self._predivide = predivide
self._e5m2_allgather = e5m2_allgather
self._do_not_flatten_model = do_not_flatten_model
self._full_pipeline = full_pipeline
self._L2_grad_norm = None
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._world_size = torch.distributed.get_world_size()
self._num_groups = self._world_size // self._group_size
self._rank_in_group = torch.distributed.get_rank() % self._group_size
p_offset = 0
p_i = 0
self._param_state = None
self._model_params = []
self._grads_info = []
self._grad_accs = []
for group in self.param_groups:
self._param_group = group
prev = None
for p in group['params']:
torch.distributed.broadcast(p,0)
if not p.requires_grad:
continue
self._model_params.append(p)
state = self.state[p]
if len(state) == 0:
state['step'] = 0
if self._param_state is None:
self._param_state = state
p_grads_size = p.numel()
def wrapper(param, param_i, param_grads_size, param_offset):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
wrapper(p, p_i, p_grads_size, p_offset)
p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
p_i += 1
self._grads_generated = [False]*len(self._grads_info)
self._flat_mt = flat_mt
self._grads = []
self._current_block = self._num_blocks
self._net_total_param_size = p_offset
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._block_size = self._total_param_size // self._num_blocks
self._shard_size = self._total_param_size // self._group_size
print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size))
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
print(self._low_param_i)
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._flat_params = torch.zeros_like(self._flat_grads)
def _flat_split(flat):
def __flat_blockify(flat):
return [flat[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __flat_shardify(flat):
return [flat[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
return __flat_blockify(flat), __flat_shardify(flat)
self._flat_grads_blocks, self._flat_grads_shards = _flat_split(self._flat_grads)
self._flat_params_blocks, self._flat_params_shards = _flat_split(self._flat_params)
# master params
self._fp32_p = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
# copy model params to flat_params and set_ model params to flat_params.
self._individual_flat_grads = []
with torch.no_grad():
for p, grads_info in zip(self._model_params, self._grads_info):
start = grads_info["param_offset"]
end = start + grads_info["param_grads_size"]
flat_p = self._flat_params[start:end].view_as(p)
flat_p.copy_(p)
p.set_(flat_p)
flat_grad = self._flat_grads[start:end]
self._individual_flat_grads.append(flat_grad)
self._fp32_p.copy_(self._flat_params_shards[self._rank_in_group].float())
self._dwu_st = torch.cuda.Stream()
self._l2_grad_norm_st = torch.cuda.Stream()
for group_i in range(self._num_groups):
ranks = [group_i*self._group_size+local_rank for local_rank in range(self._group_size)]
pg = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg = pg
torch.distributed.all_reduce(self._overflow_buf, group=self._ag_pg)
import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
@property
def has_overflow(self):
        return self.L2_grad_norm is not None and not math.isfinite(self.L2_grad_norm)
def set_last_step(self, last_step):
self._last_step = last_step
def _get_flush_block(self):
flush_block = []
if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
num_grads = len(self._grads_generated)
contiguous_idx = num_grads
while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
contiguous_idx -= 1
if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
self._current_block -= 1
start = self._current_block * self._block_size
end = (self._current_block+1) * self._block_size
flush_block = [start, end]
return flush_block
def __launch_step_kernel(self, p, p_copy, m, v, g):
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.reversible_adam(
p, p_copy, m, v, g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def _flatten_grad_mt(self, scale):
if self._flat_mt and len(self._grads) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads)),
scale)
self._grads = []
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
# handle overlapped reductions
if self._flat_mt:
self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
else:
torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
self._grads_generated[param_i]=True
if not self._last_step and self._overlap_reductions:
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._dwu_st.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._dwu_st):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
torch.distributed.all_reduce(self._flat_grads_blocks[block_id])
if block_id == 0:
self._l2_grad_norm_st.wait_stream(self._dwu_st)
with torch.cuda.stream(self._l2_grad_norm_st):
self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
flush_block = self._get_flush_block()
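    # Standalone sketch (not a method used by this class) of the stream-overlap
    # pattern in _do_overlapped_reduction: `side_stream` (hypothetical name) first
    # waits for work already queued on the current stream, the communication is
    # then queued on the side stream so it can overlap with later compute, and a
    # consumer must wait on the side stream before reading the tensor.
    def _stream_overlap_sketch(tensor, side_stream):
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            torch.distributed.all_reduce(tensor)
        torch.cuda.current_stream().wait_stream(side_stream)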
def set_global_scale(self, global_scale):
"""Set global scale.
"""
self._global_scale = global_scale
@property
def global_scale(self):
return self._global_scale
@property
def L2_grad_norm(self):
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
if self._last_step:
# zero out gradients that have not been completed yet
for param_i, flat_grad in enumerate(self._individual_flat_grads):
if not self._grads_generated[param_i]:
flat_grad.zero_()
self._grads_generated[param_i] = True
if self._last_step or not self._overlap_reductions:
# nothing done so far, run full pipeline after reductions
self._dwu_st.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._dwu_st):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
torch.distributed.all_reduce(self._flat_grads)
self._l2_grad_norm_st.wait_stream(self._dwu_st)
with torch.cuda.stream(self._l2_grad_norm_st):
self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
self._current_block = self._num_blocks
self._grads_generated = [False]*len(self._grads_info)
def step(self, closure=None, skip_overflow_check=False):
loss = None
if closure is not None:
loss = closure()
with torch.cuda.stream(self._dwu_st):
self.__launch_step_kernel(
self._fp32_p,
self._flat_params_shards[self._rank_in_group],
self._fp32_m,
self._fp32_v,
self._flat_grads_shards[self._rank_in_group])
torch.distributed.all_gather(self._flat_params_shards, self._flat_params_shards[self._rank_in_group], group=self._ag_pg, no_copy=True)
for p in self._model_params: self.state[p]['step'] += 1
torch.cuda.current_stream().wait_stream(self._dwu_st)
return loss
import argparse
import os
import random

import torch

from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()
        self.linear = torch.nn.Sequential(*[
            torch.nn.Linear(args.dim, args.dim)
            for _ in range(args.layers)
        ])

    def forward(self, x):
        y = 0
        for i, l in enumerate(self.linear):
            y += (i+1) * l(x)
        return y

def setup(args):

    # Construct models with same parameters
    ref_model = TestModel(args).float().cuda()
    dist_model = TestModel(args).float().cuda()
    with torch.no_grad():
        for ref_param, dist_param in zip(dist_model.parameters(),
                                         ref_model.parameters()):
            dist_param.data.copy_(ref_param.data)
    ref_model = torch.nn.parallel.DistributedDataParallel(
        ref_model,
        device_ids=[args.rank],
        output_device=args.rank,
    )

    # Construct optimizers with same hyperparameters
    optim_args = { 'lr': 1, 'betas': (0.5,0.75), 'eps': 0.1, 'weight_decay': 0.1 }
    ref_optim = torch.optim.AdamW(
        [
            {'params': list(ref_model.parameters())[1::2], 'lr': 0.5},
            {'params': list(ref_model.parameters())[0::2]},
        ],
        **optim_args,
    )
    dist_optim = DistributedFusedAdam(
        [
            {'params': list(dist_model.parameters())[1::2], 'lr': 0.5},
            {'params': list(dist_model.parameters())[0::2]},
        ],
        bucket_cap_mb=71/(4*1024*1024),
        **optim_args,
    )

    return ref_model, ref_optim, dist_model, dist_optim

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=3)
    parser.add_argument('--batch', type=int, default=5)
    parser.add_argument('--dim', type=int, default=7)
    parser.add_argument('--layers', type=int, default=11)
    parser.add_argument('--atol', type=float, default=1e-5)
    parser.add_argument('--rtol', type=float, default=1e-5)
    args = parser.parse_args()
    return args

def setup_env(args):

    # Initialize NCCL
    local_rank = args.local_rank
    if local_rank < 0:
        local_rank = int(os.getenv('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank % torch.cuda.device_count())
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.world_size = torch.distributed.get_world_size()

    # Initialize RNG
    seed = 42 + args.rank
    random.seed(seed)
    torch.manual_seed(seed)

    return args

def main():
    args = parse_args()
    args = setup_env(args)
    torch.set_printoptions(precision=16)

    def assert_allclose(ref_x, dist_x, message):
        message = (
            f'Rank {args.rank}: {message}\n'
            f'Reference Adam: {ref_x}\n'
            f'Distributed Adam: {dist_x}\n'
            f'Relative error: {torch.abs((ref_x-dist_x)/ref_x)}\n'
        )
        assert torch.allclose(ref_x, dist_x, atol=args.atol, rtol=args.rtol), message

    # Train model with data-parallelism and ZeRO
    ref_model, ref_optim, dist_model, dist_optim = setup(args)
    for step in range(args.steps):

        # Synthetic data
        x = torch.randn(args.batch, args.dim).cuda()
        dy = torch.randn_like(x).cuda()

        # Reference implementation
        ref_optim.zero_grad()
        x_ref = x.detach().clone().requires_grad_(True)
        y_ref = ref_model(x_ref)
        y_ref.backward(dy)
        ref_optim.step()

        # Distributed implementation
        dist_optim.zero_grad()
        x_dist = x.detach().clone().requires_grad_(True)
        y_dist = dist_model(x_dist)
        y_dist.backward(dy)
        dist_optim.step()

        # Check values
        torch.cuda.synchronize()
        torch.distributed.barrier()
        assert_allclose(
            y_ref,
            y_dist,
            f'inconsistent output in step {step}',
        )
        assert_allclose(
            x_ref.grad,
            x_dist.grad,
            f'inconsistent input grad in step {step}',
        )
        for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
                                                        dist_model.parameters())):
            assert_allclose(
                ref_param,
                dist_param,
                f'inconsistent param {i} in step {step}',
            )

if __name__ == "__main__":
    main()