Unverified commit cd499737, authored by Tim Moon, committed by GitHub

Add features to distributed Adam for Megatron support (#1414)

* Add features to distributed Adam for Megatron support

Support gradient clipping, gradient scaling, FP32 grad accumulation, and multiple dtypes and devices.

* Restore closure arg to distributed Adam

Review suggestion from @crcrpar
parent bf3c008e
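
For reference, a minimal usage sketch of the features introduced here, modeled on the new tests added in this commit (test_clip_grad_norm and test_grad_scaler). The model, data, and import path are placeholders rather than part of the commit; in particular, the module path apex.contrib.optimizers.distributed_fused_adam is assumed, and torch.distributed is assumed to be initialized already (e.g. via torchrun).

import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam  # assumed path

model = torch.nn.Linear(16, 16).cuda()                 # placeholder model
optim = DistributedFusedAdam(model.parameters(), lr=1e-3)

# Gradient clipping: clip_grad_norm caches the clip coefficient and the
# scaling is applied inside the following step()
loss = model(torch.randn(8, 16, device='cuda')).sum()
loss.backward()
grad_norm = optim.clip_grad_norm(1.0)
optim.step()
optim.zero_grad()

# Gradient scaling: step() accepts a grad_scaler and reports found_inf back
# to it (via _step_supports_amp_scaling), so the usual GradScaler loop applies
scaler = torch.cuda.amp.GradScaler()
loss = model(torch.randn(8, 16, device='cuda')).sum()
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()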
@@ -9,7 +9,7 @@ import threading
 import torch
 import amp_C
 from apex.multi_tensor_apply import multi_tensor_applier
-from torch.distributed.distributed_c10d import _get_default_group
+from torch.distributed.distributed_c10d import _get_default_group, _get_global_rank

 class DistributedFusedAdam(torch.optim.Optimizer):
     """AdamW optimizer with ZeRO algorithm.
@@ -52,7 +52,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             parameter synchronization (default: same as
             grad_sync_dtype)
         device (torch.device, optional): device for optimizer state
-            (default: cuda). Currently only supports GPU.
+            (default: cuda). Currently only supports GPU with one GPU
+            per process.
         process_group (torch.distributed.ProcessGroup, optional):
             parallel processes participating in optimizer (default:
             default group in torch.distributed). This group is
@@ -64,10 +65,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         redundant_process_group (torch.distributed.ProcessGroup,
             optional): parallel processes to replicate optimizer state
             over (default: group only containing calling process)
-        model_parallel (bool, optional): whether model parallelism is
-            used (default: False)
-        model_parallel_rank (int, optional): rank in model-parallel
-            process group (default: 0)
         average_grad_sync (bool, optional): whether to use average
             reduction for gradient synchronization rather than sum
             (default: True)
@@ -75,14 +72,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             gradient synchronization with backward pass compute
             (default: True)
         bucket_cap_mb (float, optional): bucket size in megabytes
-            (default: 15)
+            (default: 100)
         pipeline_size (int, optional): number of buckets to
             synchronize simultaneously (default: 2)
-        fused_grad_copy (bool, optional): whether to used fused kernel
-            to fill bucket with gradients (default: False). Requires
-            all parameters to have the same data type.
-        max_grad_norm (float, optional): maximum L2 norm for gradient
-            clipping (default: disabled)

     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
@@ -105,6 +97,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         # Asynchronous reduction is in progress
         SYNCING = enum.auto()

+    _step_supports_amp_scaling = True
+
     def __init__(self,
                  params,
                  lr=1e-3,
@@ -120,14 +114,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  process_group=None,
                  distributed_process_group=None,
                  redundant_process_group=None,
-                 model_parallel=False,
-                 model_parallel_rank=0,
                  average_grad_sync=True,
                  overlap_grad_sync=True,
                  bucket_cap_mb=100,
                  pipeline_size=2,
-                 fused_grad_copy=False,
-                 max_grad_norm=0.,
                  ):
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay)
@@ -142,11 +132,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             grad_sync_dtype = dtype
         if param_sync_dtype is None:
             param_sync_dtype = grad_sync_dtype
-        valid_dtypes = [
+        supported_dtypes = [
            (torch.float32, torch.float16, torch.float16),
            (torch.float32, torch.float32, torch.float32),
        ]
-        if (dtype, grad_sync_dtype, param_sync_dtype) not in valid_dtypes:
+        if (dtype, grad_sync_dtype, param_sync_dtype) not in supported_dtypes:
             raise RuntimeError(
                 'Invalid dtypes for DistributedFusedAdam '
                 f'(dtype={dtype}, '
@@ -160,18 +150,18 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self.device = device

         # Process groups
-        self.world_process_group = (
+        self.process_group = (
             _get_default_group()
             if process_group is None
             else process_group
         )
         self.distributed_process_group = (
-            self.world_process_group
+            self.process_group
             if distributed_process_group is None
             else distributed_process_group
         )
         self.redundant_process_group = redundant_process_group
-        self.world_size = torch.distributed.get_world_size(self.world_process_group)
+        self.process_group_size = torch.distributed.get_world_size(self.process_group)
         self.distributed_rank = torch.distributed.get_rank(self.distributed_process_group)
         self.distributed_size = torch.distributed.get_world_size(self.distributed_process_group)
         self.redundant_size = (
@@ -179,34 +169,22 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             if self.redundant_process_group is None
             else torch.distributed.get_world_size(self.redundant_process_group)
         )
-        if (self.world_size != self.distributed_size * self.redundant_size):
+        if self.process_group_size != self.distributed_size * self.redundant_size:
             raise RuntimeError(
                 'Invalid process group configuration '
-                f'(world process group size = {self.world_size}, '
+                f'(process group size = {self.process_group_size}, '
                 f'distributed process group size = {self.distributed_size}, '
                 f'redundant process group size = {self.redundant_size})'
             )
-        self.model_parallel = model_parallel
-        self.model_parallel_rank = model_parallel_rank
-
-        # Grad sync options
-        if fused_grad_copy:
-            _params = list(self.parameters())
-            if (_params
-                and any(p.dtype != self.grad_sync_dtype for p in _params)
-                and any(p.device != self.device for p in _params)):
-                raise RuntimeError(
-                    'Attempted to use fused gradient copy in DistributedFusedAdam, '
-                    'but parameters do not all have expected '
-                    f'dtype ({self.grad_sync_dtype}) and device ({self.device})'
-                )
+
+        # Use average reduction for grad sync
         self.average_grad_sync = average_grad_sync
+        # Copy param grads to bucket as soon as available
+        self.greedy_grad_copy = True
+        # Synchronize grad buckets as soon as all grads are available
         self.overlap_grad_sync = overlap_grad_sync
+        # Number of buckets to synchronize at a time
        self.pipeline_size = pipeline_size
-        self.fused_grad_copy = fused_grad_copy
-
-        # Grad clipping options
-        self.max_grad_norm = max_grad_norm

         # Determine bucket sizes
         dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8
@@ -230,9 +208,15 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         # Objects for gradient synchronization
         self._grads_generated = set()
-        self._grads_to_copy = []
         self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)]

+        # Divide gradients by factor before optimizer step. Used for
+        # grad clipping and gradient scaler.
+        self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device)
+        # Norm of parameter gradients. Used for gradient clipping and
+        # gradient scaler.
+        self._grad_norm = None
+
         # Check if collectives have no_copy option
         self._reduce_scatter_no_copy = (
             'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args
@@ -254,12 +238,16 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._num_grads = 0
         self._lock = threading.Lock()
         self._grad_accs = []
+        try:
+            root_rank = _get_global_rank(self.process_group, 0)
+        except:
+            root_rank = 0
         for param_group_id, group in enumerate(self.param_groups):
             for param_id, param in enumerate(group['params']):
                 torch.distributed.broadcast(
                     param,
-                    src=0,
-                    group=self.world_process_group,
+                    src=root_rank,
+                    group=self.process_group,
                 )
                 if param.requires_grad:
                     def wrapper(p, p_group_id, p_id):
@@ -269,9 +257,13 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                            with self._lock:
                                if 'fragments' not in self.state[p]:
                                    self._init_param_state(p, p_group_id, p_id)
-                                if self.overlap_grad_sync:
-                                    self._start_grad_copy(p)
-                                    self._try_start_bucket_grad_sync()
+                                if self.greedy_grad_copy:
+                                    self._grad_copy(p)
+                                if self.overlap_grad_sync:
+                                    self._try_start_bucket_grad_sync(
+                                        [p],
+                                        ignore_last_bucket=True,
+                                    )
                        grad_acc.register_hook(reduction_hook)
                        self._grad_accs.append(grad_acc)
                    wrapper(param, param_group_id, param_id)
@@ -415,13 +407,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             bucket['curr_grads_shard'] = None
             bucket['gradient_status'] = self.GradientStatus.READY
         self._grads_generated = set()
+        self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device)
+        self._grad_norm = None

-    def _start_grad_copy(self, param):
-        """Copy parameter gradient to corresponding buckets
-
-        The copy is deferred if using a fused copy kernel.
-
-        """
+    def _grad_copy(self, param):
+        """Copy parameter gradients to buckets"""

         # Copy param grad to buckets
         for fragment in self.state[param]['fragments']:
@@ -447,62 +437,26 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             # Copy param grad to bucket
             if param.grad is not None:
-                fragment_in = param.grad.view(-1)[grad_start:grad_end]
-                fragment_out = bucket['grads_bucket'][bucket_start:bucket_end]
-                self._grads_to_copy.append((fragment_in, fragment_out))
+                scale = 1/self.process_group_size if self.average_grad_sync else 1.0
+                grad_in = param.grad.detach().view(-1)[grad_start:grad_end]
+                grad_out = bucket['grads_bucket'][bucket_start:bucket_end]
+                grad_out.add_(grad_in, alpha=scale)

         # Free param grad buffer
-        if not self.fused_grad_copy:
-            self._finish_grad_copy()
         param.grad = None

-        # Update reduction statuses
-        self._grads_generated.add(param)
-        for fragment in self.state[param]['fragments']:
-            bucket_id = fragment['bucket_id']
-            bucket = self.state['buckets'][bucket_id]
-            is_filled = True
-            for other_fragment in reversed(bucket['fragments']):
-                param_group_id = other_fragment['param_group_id']
-                param_id = other_fragment['param_id']
-                other_param = self.param_groups[param_group_id]['params'][param_id]
-                if other_param not in self._grads_generated:
-                    is_filled = False
-                    break
-            if is_filled:
-                bucket['gradient_status'] = self.GradientStatus.FULLY_FILLED
-
-    def _finish_grad_copy(self):
-        """Make sure that parameter gradients have been copied to buckets
-
-        Performs any deferred copies from _start_grad_copy.
-
-        """
-        if self._grads_to_copy:
-            scale = 1/self.world_size if self.average_grad_sync else 1.0
-            if self.fused_grad_copy:
-                dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
-                multi_tensor_applier(
-                    amp_C.multi_tensor_scale,
-                    dummy_overflow_buf,
-                    list(zip(*self._grads_to_copy)),
-                    scale,
-                )
-            else:
-                for fragment_in, fragment_out in self._grads_to_copy:
-                    fragment_out.add_(fragment_in, alpha=scale)
-            self._grads_to_copy = []
-
     def _force_bucket_grad_sync(self):
         """Ensure that all gradient buckets are synchronized"""

         # Synchronize all unsynchronized buckets
         self._finish_bucket_grad_sync()
-        self._start_bucket_grad_sync([
+        buckets = [
             bucket for bucket in self.state['buckets']
             if bucket['gradient_status'] != self.GradientStatus.READY
-        ])
-        self._finish_bucket_grad_sync()
+        ]
+        if buckets:
+            self._start_bucket_grad_sync(buckets)
+            self._finish_bucket_grad_sync()

         # Fill any unfilled buckets with zeros
         for bucket in self.state['buckets']:
@@ -516,20 +470,54 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         # Reset set of generated gradients
         self._grads_generated = set()

-    def _try_start_bucket_grad_sync(self):
+    def _try_start_bucket_grad_sync(
+        self,
+        params=[],
+        ignore_last_bucket=True,
+    ):
         """Launches gradient synchronization if enough buckets are ready

         Gradient synchronization is asynchronous. Launches gradient
         synchronization if all gradients have been generated or if
         there are enough buckets ready to fill pipeline.

+        Arguments:
+            params (iterable): parameters that have had their
+                gradients copied to buckets
+            ignore_last_bucket (bool): avoid synchronizing last bucket
+                until all gradients have been generated. This avoids
+                excessive synchronization when initializing buckets in
+                the first backward pass.
+
         """
+
+        # Register params that have generated grads
+        for param in params:
+            self._grads_generated.add(param)
+            for fragment in self.state[param]['fragments']:
+                bucket_id = fragment['bucket_id']
+                bucket = self.state['buckets'][bucket_id]
+                is_filled = True
+                for other_fragment in reversed(bucket['fragments']):
+                    param_group_id = other_fragment['param_group_id']
+                    param_id = other_fragment['param_id']
+                    other_param = self.param_groups[param_group_id]['params'][param_id]
+                    if other_param not in self._grads_generated:
+                        is_filled = False
+                        break
+                if is_filled:
+                    bucket['gradient_status'] = self.GradientStatus.FULLY_FILLED
+
+        # Launch reductions if enough buckets are ready
         if len(self._grads_generated) == self._num_grads:
             self._force_bucket_grad_sync()
         else:
+            all_buckets = self.state['buckets']
+            if ignore_last_bucket:
+                all_buckets = all_buckets[:-1]
             filled_buckets = [
                 bucket
-                for bucket in self.state['buckets'][:-1]
+                for bucket in all_buckets
                 if bucket['gradient_status'] == self.GradientStatus.FULLY_FILLED
             ]
             pipeline_size = (len(filled_buckets) // self.pipeline_size) * self.pipeline_size
@@ -545,7 +533,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         """
         self._finish_bucket_grad_sync()
-        self._finish_grad_copy()

         # Reduce gradients
         for stream in self._pipeline_streams:
@@ -622,8 +609,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             # Reset status
             bucket['gradient_status'] = self.GradientStatus.READY

+        # Cached gradient norm has been invalidated
+        self._grad_norm = None
+
     @contextlib.contextmanager
-    def no_sync(self):
+    def no_sync(self, greedy_grad_copy=False):
         """Disable overlapped gradient synchronization

         Context manager that is similar to
@@ -632,12 +622,21 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         overlapped gradient synchronization is enabled, gradients can
         also be synchronized by leaving the context and performing a
         backward pass.

+        Arguments:
+            greedy_grad_copy (bool, optional): copy parameter
+                gradients to buckets as soon as they are generated
+                (default: False)
+
         """
+        old_greedy_grad_copy = self.greedy_grad_copy
         old_overlap_grad_sync = self.overlap_grad_sync
+        self.greedy_grad_copy = greedy_grad_copy
         self.overlap_grad_sync = False
         try:
             yield
         finally:
+            self.greedy_grad_copy = old_greedy_grad_copy
             self.overlap_grad_sync = old_overlap_grad_sync

     def grad_sync(self):
@@ -648,80 +647,132 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 param_id = fragment['param_id']
                 param = self.param_groups[param_group_id]['params'][param_id]
                 if param.grad is not None:
-                    self._start_grad_copy(param)
-                    self._try_start_bucket_grad_sync()
+                    self._grad_copy(param)
+                    self._try_start_bucket_grad_sync(
+                        [param],
+                        ignore_last_bucket=False,
+                    )
         self._force_bucket_grad_sync()

-    def grad_norm(self):
-        """Compute L2 norm of all parameter gradients
-
-        If model parallelism is enabled, exclude non-parallel
-        gradients on non-root processes. This is Megatron-specific, so
-        should this logic be moved elsewhere?
-
-        """
-
-        # Make sure that gradients have been reduced
-        self.grad_sync()
-
-        # Evaluate L2 norm of distributed gradients
-        dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
-        grad_norm_sq = multi_tensor_applier(
-            amp_C.multi_tensor_l2norm,
-            dummy_overflow_buf,
-            [[bucket['grads_shard'] for bucket in self.state['buckets']]],
-            False,
-        )[0] ** 2
-        torch.distributed.all_reduce(
-            grad_norm_sq,
-            group=self.distributed_process_group,
-        )
-
-        # If model parallelism is enabled, subtract non-parallel
-        # gradients on non-root processes
-        if self.model_parallel and self.model_parallel_rank:
-            non_parallel_grads = []
-            for bucket in self.state['buckets']:
-                for fragment in bucket['fragments']:
-                    if fragment['in_local_shard']:
-                        param_group_id = fragment['param_group_id']
-                        param_id = fragment['param_id']
-                        param = self.param_groups[param_group_id]['params'][param_id]
-                        if (hasattr(param, 'model_parallel')
-                            and not param.model_parallel):
-                            shard_start, shard_end = fragment['shard_range']
-                            non_parallel_grads.append(
-                                bucket['grads_shard'][shard_start:shard_end]
-                            )
-            if non_parallel_grads:
-                dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
-                non_parallel_grad_norm_sq = multi_tensor_applier(
-                    amp_C.multi_tensor_l2norm,
-                    dummy_overflow_buf,
-                    [non_parallel_grads],
-                    False,
-                )[0] ** 2
-            else:
-                non_parallel_grad_norm_sq = torch.zeros([1], device=self.device)
-            torch.distributed.all_reduce(
-                non_parallel_grad_norm_sq,
-                group=self.distributed_process_group,
-            )
-            grad_norm_sq -= non_parallel_grad_norm_sq
-
-        return grad_norm_sq.sqrt()
-
-    def step(self, closure=None, scale=1.):
+    def _local_grad_norm(self, parameters=[], norm_type=2.0):
+        """Local contribution to parameter gradient norm
+
+        Returns square of 2-norm. Other norms are not yet supported.
+
+        If no parameters are provided, the norm is computed for all
+        parameters in optimizer. Provided parameters are assumed to be
+        in optimizer.
+
+        """
+        norm_type = float(norm_type)
+        assert norm_type == 2.0
+
+        # Make sure that gradients have been reduced
+        self.grad_sync()
+
+        if not parameters or len(parameters) == self._num_grads:
+            # Compute norm of all local gradients
+            dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
+            grad_norm_sq = multi_tensor_applier(
+                amp_C.multi_tensor_l2norm,
+                dummy_overflow_buf,
+                [[bucket['grads_shard'] for bucket in self.state['buckets']]],
+                False,
+            )[0] ** 2
+        else:
+            # Compute norm of selected local gradients
+            grads = []
+            for param in parameters:
+                for fragment in self.state[param]['fragments']:
+                    if fragment['in_local_shard']:
+                        bucket_id = fragment['bucket_id']
+                        bucket = self.state['buckets'][bucket_id]
+                        shard_start, shard_end = fragment['shard_range']
+                        grads.append(bucket['grads_shard'][shard_start:shard_end])
+            if grads:
+                dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
+                grad_norm_sq = multi_tensor_applier(
+                    amp_C.multi_tensor_l2norm,
+                    dummy_overflow_buf,
+                    [grads],
+                    False,
+                )[0] ** 2
+            else:
+                grad_norm_sq = torch.zeros([1], dtype=torch.float32, device=self.device)
+        return grad_norm_sq.detach().view([])
+
+    def grad_norm(self, parameters=[], norm_type=2.0, force=False):
+        """Gradient norm of parameters in optimizer
+
+        The norm is computed over all gradients together, as if they
+        were concatenated into a single vector. All provided
+        parameters must be managed by optimizer.
+
+        The computed value is cached to avoid redundant communication.
+
+        Arguments:
+            parameters (iterable, optional): an iterable of parameters
+                in optimizer (default: all parameters in optimizer).
+            norm_type (float or int, optional): type of the used
+                p-norm (default: 2). Only 2-norm is currently
+                supported.
+            force (bool, optional): ignore cached value and force norm
+                computation (default: False).
+
+        """
+        if force or self._grad_norm is None:
+            norm_type = float(norm_type)
+            assert norm_type == 2.0
+            grad_norm_sq = self._local_grad_norm(
+                parameters=parameters,
+                norm_type=norm_type,
+            )
+            torch.distributed.all_reduce(
+                grad_norm_sq,
+                op=torch.distributed.ReduceOp.SUM,
+                group=self.distributed_process_group,
+            )
+            self._grad_norm = grad_norm_sq.sqrt()
+        return self._grad_norm.detach()
+
+    def clip_grad_norm(self, max_norm, parameters=[], norm_type=2.0):
+        """Clips gradient norm of parameters in optimizer
+
+        The norm is computed over all gradients together, as if they
+        were concatenated into a single vector. The scaling is
+        deferred until the optimizer step, which should be called
+        immediately after this function.
+
+        The computed grad norm is cached to avoid redundant
+        communication.
+
+        Arguments:
+            max_norm (float or int): max norm of the gradients
+            parameters (iterable, optional): an iterable of parameters
+                in optimizer (default: all parameters in optimizer).
+            norm_type (float or int, optional): type of the used
+                p-norm (default: 2)
+
+        """
+        assert max_norm > 0
+        total_norm = self.grad_norm(parameters=parameters, norm_type=norm_type)
+        inv_clip_coef = (total_norm + 1e-6) / max_norm
+        self._inv_grad_scale = torch.clamp(inv_clip_coef, min=1.0).view(1)
+        return total_norm
+
+    def step(self, closure=None, *, grad_scaler=None):
         """Apply Adam optimizer step

         Arguments:
             closure (callable, optional): closure to recompute loss
                 (default: None)
-            scale (float, optional): scaling factor to divide
-                gradients (default: 1.0)
+            grad_scaler (torch.cuda.amp.GradScaler, optional):
+                gradient scaler (default: None)

         """
-        self.state['step'] += 1
-
+
+        # Apply closure
         loss = None
         if closure is not None:
             loss = closure()
@@ -729,14 +780,26 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         # Make sure that gradients have been reduced
         self.grad_sync()

-        # Scale gradient if L2 norm is too large
-        if self.max_grad_norm > 0:
-            grad_norm = self.grad_norm().item()
-            if (math.isfinite(grad_norm)
-                and grad_norm / scale > self.max_grad_norm):
-                scale = grad_norm / self.max_grad_norm
+        # Apply gradient scaler if provided
+        # Note: We compute gradient norm to check for non-finite
+        # values. This is more conservative and compute intensive than
+        # directly checking, but it avoids extra communication if we
+        # have already computed gradient norm e.g. for gradient
+        # clipping.
+        if grad_scaler is not None:
+            grad_norm = self.grad_norm()
+            found_inf = torch.logical_not(torch.isfinite(grad_norm))
+            scaler_state = grad_scaler._per_optimizer_states[id(self)]
+            scaler_state['found_inf_per_device'] = {found_inf.device: found_inf.float()}
+            if found_inf.item():
+                return
+            else:
+                assert grad_scaler._scale is not None
+                self._inv_grad_scale *= grad_scaler._scale
+        inv_grad_scale = self._inv_grad_scale.item()

         # Apply optimizer step to each bucket and synchronize params
+        self.state['step'] += 1
         current_stream = torch.cuda.current_stream()
         for stream in self._pipeline_streams:
             stream.wait_stream(current_stream)
@@ -806,14 +869,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     eps,
                     weight_decay,
                     group['lr'],
-                    scale,
+                    inv_grad_scale,
                     self.state['step'],
                     1, # Set to 0 to apply eps inside sqrt
                 )
+            del group_buffers

             # Deallocate buffers
             del buffers
+            bucket['grads_shard'] = None

             # Allgather updated parameters
             if self.distributed_size == 1:
@@ -842,24 +905,33 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             del params_shard_copy

             # Copy values to param buffers
-            params_in = []
-            params_out = []
+            buffers = collections.defaultdict(list)  # param_in, param_out
             for fragment in bucket['fragments']:
                 param_group_id = fragment['param_group_id']
                 param_id = fragment['param_id']
                 param = self.param_groups[param_group_id]['params'][param_id]
                 bucket_start, bucket_end = fragment['bucket_range']
                 param_start, param_end = fragment['param_range']
-                params_in.append(params_bucket[bucket_start:bucket_end])
-                params_out.append(param.view(-1)[param_start:param_end])
-            if params_in:
-                dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
-                multi_tensor_applier(
-                    fused_adam_cuda.maybe_cast_mt,
-                    dummy_overflow_buf,
-                    [params_in, params_out],
-                )
-            del params_bucket, params_in, params_out
+                buffers[(param.is_cuda, param.dtype)].append((
+                    params_bucket[bucket_start:bucket_end],
+                    param.detach().view(-1)[param_start:param_end],
+                ))
+            for (is_cuda, dtype), dtype_buffers in buffers.items():
+                fused_kernel_dtypes = (torch.float32, torch.float16, torch.uint8)
+                if (is_cuda
+                    and dtype in fused_kernel_dtypes
+                    and self.param_sync_dtype in fused_kernel_dtypes):
+                    dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
+                    multi_tensor_applier(
+                        fused_adam_cuda.maybe_cast_mt,
+                        dummy_overflow_buf,
+                        list(zip(*dtype_buffers)),
+                    )
+                else:
+                    for param_in, param_out in dtype_buffers:
+                        param_out.copy_(param_in)
+                del dtype_buffers
+            del params_bucket, buffers

         # Synchronize pipeline streams
         for stream in self._pipeline_streams:

@@ -21,11 +21,17 @@ class SimpleModel(torch.nn.Module):
             y += (i+1) * l(x)
         return y

-def make_models(num_layers, size, dtype=torch.float32, overlap_communication=True):
+def make_models(
+    num_layers,
+    size,
+    dtype=torch.float32,
+    device='cuda',
+    overlap_communication=True,
+):

     # Construct models with same parameters
-    ref_model = SimpleModel(num_layers, size).to(dtype=dtype, device='cuda')
-    dist_model = SimpleModel(num_layers, size).to(dtype=dtype, device='cuda')
+    ref_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device)
+    dist_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device)
     with torch.no_grad():
         for ref_param, dist_param in zip(dist_model.parameters(),
                                          ref_model.parameters()):
@@ -35,22 +41,22 @@ def make_models(num_layers, size, dtype=torch.float32, overlap_communication=Tru
     rank = torch.distributed.get_rank()
     ref_model = torch.nn.parallel.DistributedDataParallel(
         ref_model,
-        device_ids=[rank],
-        output_device=rank,
+        device_ids=[rank] if device=='cuda' else None,
+        output_device=rank if device=='cuda' else None,
     )

     # Construct optimizers with same hyperparameters
-    optim_args = { 'lr': 1, 'betas': (0.1,0.2), 'eps': 0.1, 'weight_decay': 0.1 }
+    optim_args = dict(lr=0.1, betas=(0.1,0.2), eps=0.25, weight_decay=0.1)
     ref_optim = torch.optim.AdamW(
         [
-            {'params': list(ref_model.parameters())[1::2], 'lr': 0.5},
+            {'params': list(ref_model.parameters())[1::2], 'lr': 0.2},
             {'params': list(ref_model.parameters())[0::2]},
         ],
         **optim_args,
     )
     dist_optim = DistributedFusedAdam(
         [
-            {'params': list(dist_model.parameters())[1::2], 'lr': 0.5},
+            {'params': list(dist_model.parameters())[1::2], 'lr': 0.2},
             {'params': list(dist_model.parameters())[0::2]},
         ],
         overlap_grad_sync=overlap_communication,
@@ -81,8 +87,9 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
         overlap_communication=True,
         use_nosync=True,
         dtype=torch.float32,
-        rtol=1e-5,
-        atol=1e-5,
+        device='cuda',
+        rtol=None,
+        atol=None,
     ):
         torch.manual_seed(self.seed + self.rank)
@@ -92,6 +99,7 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
             num_layers,
             layer_size,
             dtype=dtype,
+            device=device,
             overlap_communication=overlap_communication,
         )
@@ -106,10 +114,10 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
         for micro_step in range(micro_batch_steps):

             # Synthetic data
-            x = torch.rand(batch_size, layer_size) + 0.5
-            dy = torch.rand_like(x) + 0.5
-            x = x.to(dtype=dtype, device='cuda')
-            dy = dy.to(dtype=dtype, device='cuda')
+            x = torch.rand(batch_size, layer_size) - 0.5
+            dy = torch.rand_like(x) - 0.5
+            x = x.to(dtype=dtype, device=device)
+            dy = dy.to(dtype=dtype, device=device)

             # Reference implementation
             x_ref = x.detach().clone().requires_grad_(True)
@@ -136,8 +144,8 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
             dist_optim.step()

         # Check that parameters match
-        for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
-                                                        dist_model.parameters())):
+        for ref_param, dist_param in zip(ref_model.parameters(),
+                                         dist_model.parameters()):
             torch.testing.assert_close(
                 dist_param, ref_param, rtol=rtol, atol=atol)
@@ -150,6 +158,20 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
     def test_matches_pytorch_sync_every_step(self):
         self.test_matches_pytorch(use_nosync=False)

+    def test_matches_pytorch_fp64(self):
+        self.test_matches_pytorch(
+            dtype=torch.float64,
+            rtol=1.3e-6,
+            atol=1e-5,
+        )
+
+    def test_matches_pytorch_fp16(self):
+        self.test_matches_pytorch(
+            dtype=torch.float16,
+            rtol=1e-2,
+            atol=1e-2,
+        )
+
     def test_raises_on_mismatch(self):
         torch.manual_seed(self.seed + self.rank)
@@ -172,16 +194,89 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
         dist_optim.step()

         # Check that parameters do not match
-        for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
-                                                        dist_model.parameters())):
+        for ref_param, dist_param in zip(ref_model.parameters(),
+                                         dist_model.parameters()):
             self.assertRaises(
                 AssertionError,
                 torch.testing.assert_close,
                 dist_param, ref_param,
-                rtol=1e-5,
-                atol=1e-5,
             )

+    def test_clip_grad_norm(self):
+        torch.manual_seed(self.seed + self.rank)
+
+        # Identical models with data-parallel and ZeRO
+        ref_model, ref_optim, dist_model, dist_optim = make_models(1, 1)
+
+        # Training steps with pre-determined gradients
+        xs = [3, 1, 4, 1, 5, 9]
+        dys = [1, -1, 1, -1, 1, -1]
+        for x, dy in zip(xs, dys):
+            x = torch.tensor([x], dtype=torch.float32, device='cuda')
+            dy = torch.tensor([dy], dtype=torch.float32, device='cuda')
+
+            # Reference implementation
+            ref_optim.zero_grad()
+            y_ref = ref_model(x.detach())
+            y_ref.backward(dy.detach())
+            ref_grad_norm = torch.nn.utils.clip_grad_norm_(ref_model.parameters(), 3.5)
+            ref_optim.step()
+
+            # Distributed implementation
+            dist_optim.zero_grad()
+            y_dist = dist_model(x.detach())
+            y_dist.backward(dy.detach())
+            dist_grad_norm = dist_optim.clip_grad_norm(3.5)
+            dist_optim.step()
+
+            # Check that parameters match
+            torch.testing.assert_close(dist_grad_norm, ref_grad_norm)
+            for ref_param, dist_param in zip(ref_model.parameters(),
+                                             dist_model.parameters()):
+                torch.testing.assert_close(dist_param, ref_param)
+
+    def test_grad_scaler(self):
+        torch.manual_seed(self.seed + self.rank)
+
+        # Identical models with data-parallel and ZeRO
+        ref_model, ref_optim, dist_model, dist_optim = make_models(1, 1)
+        grad_scaler_args = dict(
+            init_scale=3.21,
+            growth_factor=1.23,
+            backoff_factor=0.876,
+            growth_interval=1,
+        )
+        ref_scaler = torch.cuda.amp.GradScaler(**grad_scaler_args)
+        dist_scaler = torch.cuda.amp.GradScaler(**grad_scaler_args)
+
+        # Training steps with pre-determined gradients
+        xs = [3, 1, 4, 1, 5, 9]
+        dys = [1, float('inf'), 1, 1, float('nan'), -1]
+        for x, dy in zip(xs, dys):
+            x = torch.tensor([x], dtype=torch.float32, device='cuda')
+            dy = torch.tensor([dy], dtype=torch.float32, device='cuda')
+
+            # Reference implementation
+            ref_optim.zero_grad()
+            y_ref = ref_model(x.detach())
+            ref_scaler.scale(y_ref).backward(dy.detach())
+            ref_scaler.step(ref_optim)
+            ref_scaler.update()
+
+            # Distributed implementation
+            dist_optim.zero_grad()
+            y_dist = dist_model(x.detach())
+            dist_scaler.scale(y_dist).backward(dy.detach())
+            dist_scaler.step(dist_optim)
+            dist_scaler.update()
+
+            # Check that parameters match
+            for ref_param, dist_param in zip(ref_model.parameters(),
+                                             dist_model.parameters()):
+                torch.testing.assert_close(dist_param, ref_param)
+
 if __name__ == "__main__":
     # Assume script has been run with torchrun
     common_utils.run_tests()
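
A final hedged illustration of the reworked no_sync() context manager shown in the optimizer diff above, which gains a greedy_grad_copy option. This is a sketch of a possible gradient-accumulation loop, not code from the commit; model, optim, loss_fn, and micro_batches are placeholders carried over from the earlier sketch.

# Accumulate gradients over micro-batches; only the final backward pass
# (outside no_sync) triggers the overlapped bucket reductions.
for i, (x, y) in enumerate(micro_batches):              # placeholder data
    loss = loss_fn(model(x), y)                         # placeholder model/loss
    if i + 1 < len(micro_batches):
        # Copy grads into buckets eagerly but skip launching reductions
        with optim.no_sync(greedy_grad_copy=True):
            loss.backward()
    else:
        loss.backward()                                 # reductions overlap with this backward
optim.step()
optim.zero_grad()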