Unverified Commit cd499737 authored by Tim Moon, committed by GitHub

Add features to distributed Adam for Megatron support (#1414)

* Add features to distributed Adam for Megatron support

Support gradient clipping, gradient scaling, FP32 grad accumulation, and multiple dtypes and devices.

* Restore closure arg to distributed Adam

Review suggestion from @crcrpar
parent bf3c008e
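For context, the features landed here (overlapped gradient sync plus deferred gradient clipping) can be exercised roughly as follows. This is an illustrative sketch, not part of the commit: the import path, model, and synthetic data are assumptions, and the script is assumed to be launched with torchrun so that torch.distributed can initialize from environment variables.

import os
import torch
import torch.nn.functional as F
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam  # path assumed

torch.distributed.init_process_group(backend='nccl')        # rank/world size from torchrun env
torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0))) # one GPU per process
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = DistributedFusedAdam(model.parameters(), lr=1e-3, overlap_grad_sync=True)

# Synthetic data standing in for a real data loader
loader = [(torch.randn(32, 1024, device='cuda'), torch.randn(32, 1024, device='cuda'))
          for _ in range(4)]
for x, target in loader:
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), target)
    loss.backward()                    # gradient buckets are reduced asynchronously here
    optimizer.clip_grad_norm(1.0)      # records a deferred scale, applied inside step()
    optimizer.step()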
@@ -9,7 +9,7 @@ import threading
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.distributed_c10d import _get_default_group, _get_global_rank
class DistributedFusedAdam(torch.optim.Optimizer):
"""AdamW optimizer with ZeRO algorithm.
@@ -52,7 +52,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
parameter synchronization (default: same as
grad_sync_dtype)
device (torch.device, optional): device for optimizer state
(default: cuda). Currently only supports GPU.
(default: cuda). Currently only supports GPU with one GPU
per process.
process_group (torch.distributed.ProcessGroup, optional):
parallel processes participating in optimizer (default:
default group in torch.distributed). This group is
@@ -64,10 +65,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
redundant_process_group (torch.distributed.ProcessGroup,
optional): parallel processes to replicate optimizer state
over (default: group only containing calling process)
model_parallel (bool, optional): whether model parallelism is
used (default: False)
model_parallel_rank (int, optional): rank in model-parallel
process group (default: 0)
average_grad_sync (bool, optional): whether to use average
reduction for gradient synchronization rather than sum
(default: True)
@@ -75,14 +72,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
gradient synchronization with backward pass compute
(default: True)
bucket_cap_mb (float, optional): bucket size in megabytes
(default: 15)
(default: 100)
pipeline_size (int, optional): number of buckets to
synchronize simultaneously (default: 2)
fused_grad_copy (bool, optional): whether to used fused kernel
to fill bucket with gradients (default: False). Requires
all parameters to have the same data type.
max_grad_norm (float, optional): maximum L2 norm for gradient
clipping (default: disabled)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
@@ -105,6 +97,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# Asynchronous reduction is in progress
SYNCING = enum.auto()
_step_supports_amp_scaling = True
def __init__(self,
params,
lr=1e-3,
@@ -120,14 +114,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
process_group=None,
distributed_process_group=None,
redundant_process_group=None,
model_parallel=False,
model_parallel_rank=0,
average_grad_sync=True,
overlap_grad_sync=True,
bucket_cap_mb=100,
pipeline_size=2,
fused_grad_copy=False,
max_grad_norm=0.,
):
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay)
@@ -142,11 +132,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
grad_sync_dtype = dtype
if param_sync_dtype is None:
param_sync_dtype = grad_sync_dtype
valid_dtypes = [
supported_dtypes = [
(torch.float32, torch.float16, torch.float16),
(torch.float32, torch.float32, torch.float32),
]
if (dtype, grad_sync_dtype, param_sync_dtype) not in valid_dtypes:
if (dtype, grad_sync_dtype, param_sync_dtype) not in supported_dtypes:
raise RuntimeError(
'Invalid dtypes for DistributedFusedAdam '
f'(dtype={dtype}, '
@@ -160,18 +150,18 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self.device = device
# Process groups
self.world_process_group = (
self.process_group = (
_get_default_group()
if process_group is None
else process_group
)
self.distributed_process_group = (
self.world_process_group
self.process_group
if distributed_process_group is None
else distributed_process_group
)
self.redundant_process_group = redundant_process_group
self.world_size = torch.distributed.get_world_size(self.world_process_group)
self.process_group_size = torch.distributed.get_world_size(self.process_group)
self.distributed_rank = torch.distributed.get_rank(self.distributed_process_group)
self.distributed_size = torch.distributed.get_world_size(self.distributed_process_group)
self.redundant_size = (
@@ -179,34 +169,22 @@ class DistributedFusedAdam(torch.optim.Optimizer):
if self.redundant_process_group is None
else torch.distributed.get_world_size(self.redundant_process_group)
)
if (self.world_size != self.distributed_size * self.redundant_size):
if self.process_group_size != self.distributed_size * self.redundant_size:
raise RuntimeError(
'Invalid process group configuration '
f'(world process group size = {self.world_size}, '
f'(process group size = {self.process_group_size}, '
f'distributed process group size = {self.distributed_size}, '
f'redundant process group size = {self.redundant_size})'
)
self.model_parallel = model_parallel
self.model_parallel_rank = model_parallel_rank
# Grad sync options
if fused_grad_copy:
_params = list(self.parameters())
if (_params
and any(p.dtype != self.grad_sync_dtype for p in _params)
and any(p.device != self.device for p in _params)):
raise RuntimeError(
'Attempted to use fused gradient copy in DistributedFusedAdam, '
'but parameters do not all have expected '
f'dtype ({self.grad_sync_dtype}) and device ({self.device})'
)
# Use average reduction for grad sync
self.average_grad_sync = average_grad_sync
# Copy param grads to bucket as soon as available
self.greedy_grad_copy = True
# Synchronize grad buckets as soon as all grads are available
self.overlap_grad_sync = overlap_grad_sync
# Number of buckets to synchronize at a time
self.pipeline_size = pipeline_size
self.fused_grad_copy = fused_grad_copy
# Grad clipping options
self.max_grad_norm = max_grad_norm
# Determine bucket sizes
dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8
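The hunk above replaces the world/model-parallel bookkeeping with a single process_group whose size must factor as distributed_size * redundant_size (state sharding times state replication). One layout that satisfies the check is sketched below; the contiguous shard blocks and strided replication groups are assumptions for illustration, the diff itself only validates the sizes, and torch.distributed is assumed to be initialized.

import torch.distributed as dist

world_size = dist.get_world_size()              # e.g. 8 ranks
redundant_size = 2                              # replicate optimizer state twice
distributed_size = world_size // redundant_size # shard optimizer state across 4 ranks
rank = dist.get_rank()

distributed_group = None
redundant_group = None
for block in range(redundant_size):             # shard within each contiguous block of ranks
    ranks = list(range(block * distributed_size, (block + 1) * distributed_size))
    group = dist.new_group(ranks=ranks)         # every rank must create every group
    if rank in ranks:
        distributed_group = group
for offset in range(distributed_size):          # replicate the same shard across blocks
    ranks = list(range(offset, world_size, distributed_size))
    group = dist.new_group(ranks=ranks)
    if rank in ranks:
        redundant_group = group
# These groups would be passed as distributed_process_group and
# redundant_process_group; the default process_group spans all 8 ranks,
# so 8 == 4 * 2 and the size check above passes.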
@@ -230,9 +208,15 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# Objects for gradient synchronization
self._grads_generated = set()
self._grads_to_copy = []
self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)]
# Divide gradients by factor before optimizer step. Used for
# grad clipping and gradient scaler.
self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device)
# Norm of parameter gradients. Used for gradient clipping and
# gradient scaler.
self._grad_norm = None
# Check if collectives have no_copy option
self._reduce_scatter_no_copy = (
'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args
@@ -254,12 +238,16 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._num_grads = 0
self._lock = threading.Lock()
self._grad_accs = []
try:
root_rank = _get_global_rank(self.process_group, 0)
except:
root_rank = 0
for param_group_id, group in enumerate(self.param_groups):
for param_id, param in enumerate(group['params']):
torch.distributed.broadcast(
param,
src=0,
group=self.world_process_group,
src=root_rank,
group=self.process_group,
)
if param.requires_grad:
def wrapper(p, p_group_id, p_id):
@@ -269,9 +257,13 @@ class DistributedFusedAdam(torch.optim.Optimizer):
with self._lock:
if 'fragments' not in self.state[p]:
self._init_param_state(p, p_group_id, p_id)
if self.overlap_grad_sync:
self._start_grad_copy(p)
self._try_start_bucket_grad_sync()
if self.greedy_grad_copy:
self._grad_copy(p)
if self.overlap_grad_sync:
self._try_start_bucket_grad_sync(
[p],
ignore_last_bucket=True,
)
grad_acc.register_hook(reduction_hook)
self._grad_accs.append(grad_acc)
wrapper(param, param_group_id, param_id)
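The hunk above hooks each parameter's gradient accumulator so that, with greedy_grad_copy, gradients are copied into buckets as soon as they are produced and bucket reductions can start while the backward pass is still running. A minimal standalone version of that hook pattern follows; the expand_as trick for obtaining the AccumulateGrad node is an assumption about the elided setup code in this file.

import torch

param = torch.nn.Parameter(torch.randn(4))
grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]  # AccumulateGrad node

def reduction_hook(*unused):
    # By the time this fires, param.grad has been accumulated and could be
    # copied into a gradient bucket (what _grad_copy does above).
    print('grad ready, norm =', param.grad.norm().item())

grad_acc.register_hook(reduction_hook)
(param * 2.0).sum().backward()   # hook fires during the backward pass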
@@ -415,13 +407,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
bucket['curr_grads_shard'] = None
bucket['gradient_status'] = self.GradientStatus.READY
self._grads_generated = set()
self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device)
self._grad_norm = None
def _start_grad_copy(self, param):
"""Copy parameter gradient to corresponding buckets
The copy is deferred if using a fused copy kernel.
"""
def _grad_copy(self, param):
"""Copy parameter gradients to buckets"""
# Copy param grad to buckets
for fragment in self.state[param]['fragments']:
@@ -447,62 +437,26 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# Copy param grad to bucket
if param.grad is not None:
fragment_in = param.grad.view(-1)[grad_start:grad_end]
fragment_out = bucket['grads_bucket'][bucket_start:bucket_end]
self._grads_to_copy.append((fragment_in, fragment_out))
scale = 1/self.process_group_size if self.average_grad_sync else 1.0
grad_in = param.grad.detach().view(-1)[grad_start:grad_end]
grad_out = bucket['grads_bucket'][bucket_start:bucket_end]
grad_out.add_(grad_in, alpha=scale)
# Free param grad buffer
if not self.fused_grad_copy:
self._finish_grad_copy()
param.grad = None
# Update reduction statuses
self._grads_generated.add(param)
for fragment in self.state[param]['fragments']:
bucket_id = fragment['bucket_id']
bucket = self.state['buckets'][bucket_id]
is_filled = True
for other_fragment in reversed(bucket['fragments']):
param_group_id = other_fragment['param_group_id']
param_id = other_fragment['param_id']
other_param = self.param_groups[param_group_id]['params'][param_id]
if other_param not in self._grads_generated:
is_filled = False
break
if is_filled:
bucket['gradient_status'] = self.GradientStatus.FULLY_FILLED
def _finish_grad_copy(self):
"""Make sure that parameter gradients have been copied to buckets
Performs any deferred copies from _start_grad_copy.
"""
if self._grads_to_copy:
scale = 1/self.world_size if self.average_grad_sync else 1.0
if self.fused_grad_copy:
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
multi_tensor_applier(
amp_C.multi_tensor_scale,
dummy_overflow_buf,
list(zip(*self._grads_to_copy)),
scale,
)
else:
for fragment_in, fragment_out in self._grads_to_copy:
fragment_out.add_(fragment_in, alpha=scale)
self._grads_to_copy = []
def _force_bucket_grad_sync(self):
"""Ensure that all gradient buckets are synchronized"""
# Synchronize all unsynchronized buckets
self._finish_bucket_grad_sync()
self._start_bucket_grad_sync([
buckets = [
bucket for bucket in self.state['buckets']
if bucket['gradient_status'] != self.GradientStatus.READY
])
self._finish_bucket_grad_sync()
]
if buckets:
self._start_bucket_grad_sync(buckets)
self._finish_bucket_grad_sync()
# Fill any unfilled buckets with zeros
for bucket in self.state['buckets']:
@@ -516,20 +470,54 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# Reset set of generated gradients
self._grads_generated = set()
def _try_start_bucket_grad_sync(self):
def _try_start_bucket_grad_sync(
self,
params=[],
ignore_last_bucket=True,
):
"""Launches gradient synchronization if enough buckets are ready
Gradient synchronization is asynchronous. Launches gradient
synchronization if all gradients have been generated or if
there are enough buckets ready to fill pipeline.
Arguments:
params (iterable): parameters that have had their
gradients copied to buckets
ignore_last_bucket (bool): avoid synchronizing last bucket
until all gradients have been generated. This avoids
excessive synchronization when initializing buckets in
the first backward pass.
"""
# Register params that have generated grads
for param in params:
self._grads_generated.add(param)
for fragment in self.state[param]['fragments']:
bucket_id = fragment['bucket_id']
bucket = self.state['buckets'][bucket_id]
is_filled = True
for other_fragment in reversed(bucket['fragments']):
param_group_id = other_fragment['param_group_id']
param_id = other_fragment['param_id']
other_param = self.param_groups[param_group_id]['params'][param_id]
if other_param not in self._grads_generated:
is_filled = False
break
if is_filled:
bucket['gradient_status'] = self.GradientStatus.FULLY_FILLED
# Launch reductions if enough buckets are ready
if len(self._grads_generated) == self._num_grads:
self._force_bucket_grad_sync()
else:
all_buckets = self.state['buckets']
if ignore_last_bucket:
all_buckets = all_buckets[:-1]
filled_buckets = [
bucket
for bucket in self.state['buckets'][:-1]
for bucket in all_buckets
if bucket['gradient_status'] == self.GradientStatus.FULLY_FILLED
]
pipeline_size = (len(filled_buckets) // self.pipeline_size) * self.pipeline_size
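The rounding above launches bucket reductions only in whole multiples of pipeline_size, so a partially filled pipeline waits for more gradients. A toy illustration with made-up numbers:

pipeline_size = 2
filled_buckets = ['b0', 'b1', 'b2', 'b3', 'b4']                    # 5 buckets ready
launch_count = (len(filled_buckets) // pipeline_size) * pipeline_size
print(launch_count)  # 4: two groups of 2 launch now, the last bucket waits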
@@ -545,7 +533,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
"""
self._finish_bucket_grad_sync()
self._finish_grad_copy()
# Reduce gradients
for stream in self._pipeline_streams:
@@ -622,8 +609,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# Reset status
bucket['gradient_status'] = self.GradientStatus.READY
# Cached gradient norm has been invalidated
self._grad_norm = None
@contextlib.contextmanager
def no_sync(self):
def no_sync(self, greedy_grad_copy=False):
"""Disable overlapped gradient synchronization
Context manager that is similar to
@@ -632,12 +622,21 @@ class DistributedFusedAdam(torch.optim.Optimizer):
overlapped gradient synchronization is enabled, gradients can
also be synchronized by leaving the context and performing a
backward pass.
Arguments:
greedy_grad_copy (bool, optional): copy parameter
gradients to buckets as soon as they are generated
(default: False)
"""
old_greedy_grad_copy = self.greedy_grad_copy
old_overlap_grad_sync = self.overlap_grad_sync
self.greedy_grad_copy = greedy_grad_copy
self.overlap_grad_sync = False
try:
yield
finally:
self.greedy_grad_copy = old_greedy_grad_copy
self.overlap_grad_sync = old_overlap_grad_sync
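One way the extended no_sync() might be used for gradient accumulation: early micro-batches copy gradients into buckets without communicating, and the final backward pass (outside the context) triggers the overlapped reductions. The model and synthetic micro-batches below are placeholders, the import path is assumed, and torch.distributed is assumed to be initialized.

import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam  # path assumed

model = torch.nn.Linear(16, 16).cuda()                   # placeholder model
optimizer = DistributedFusedAdam(model.parameters(), overlap_grad_sync=True)
microbatches = [torch.randn(8, 16, device='cuda') for _ in range(4)]

for i, x in enumerate(microbatches):
    if i == len(microbatches) - 1:
        model(x).sum().backward()                        # reductions overlap with this backward
    else:
        with optimizer.no_sync(greedy_grad_copy=True):
            model(x).sum().backward()                    # grads copied to buckets, no collective
optimizer.step()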
def grad_sync(self):
@@ -648,80 +647,132 @@ class DistributedFusedAdam(torch.optim.Optimizer):
param_id = fragment['param_id']
param = self.param_groups[param_group_id]['params'][param_id]
if param.grad is not None:
self._start_grad_copy(param)
self._try_start_bucket_grad_sync()
self._grad_copy(param)
self._try_start_bucket_grad_sync(
[param],
ignore_last_bucket=False,
)
self._force_bucket_grad_sync()
def grad_norm(self):
"""Compute L2 norm of all parameter gradients
def _local_grad_norm(self, parameters=[], norm_type=2.0):
"""Local contribution to parameter gradient norm
If model parallelism is enabled, exclude non-parallel
gradients on non-root processes. This is Megatron-specific, so
should this logic be moved elsewhere?
Returns square of 2-norm. Other norms are not yet supported.
If no parameters are provided, the norm is computed for all
parameters in optimizer. Provided parameters are assumed to be
in optimizer.
"""
norm_type = float(norm_type)
assert norm_type == 2.0
# Make sure that gradients have been reduced
self.grad_sync()
# Evaluate L2 norm of distributed gradients
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
grad_norm_sq = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[[bucket['grads_shard'] for bucket in self.state['buckets']]],
False,
)[0] ** 2
torch.distributed.all_reduce(
grad_norm_sq,
group=self.distributed_process_group,
)
# If model parallelism is enabled, subtract non-parallel
# gradients on non-root processes
if self.model_parallel and self.model_parallel_rank:
non_parallel_grads = []
for bucket in self.state['buckets']:
for fragment in bucket['fragments']:
if not parameters or len(parameters) == self._num_grads:
# Compute norm of all local gradients
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
grad_norm_sq = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[[bucket['grads_shard'] for bucket in self.state['buckets']]],
False,
)[0] ** 2
else:
# Compute norm of selected local gradients
grads = []
for param in parameters:
for fragment in self.state[param]['fragments']:
if fragment['in_local_shard']:
param_group_id = fragment['param_group_id']
param_id = fragment['param_id']
param = self.param_groups[param_group_id]['params'][param_id]
if (hasattr(param, 'model_parallel')
and not param.model_parallel):
shard_start, shard_end = fragment['shard_range']
non_parallel_grads.append(
bucket['grads_shard'][shard_start:shard_end]
)
if non_parallel_grads:
bucket_id = fragment['bucket_id']
bucket = self.state['buckets'][bucket_id]
shard_start, shard_end = fragment['shard_range']
grads.append(bucket['grads_shard'][shard_start:shard_end])
if grads:
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
non_parallel_grad_norm_sq = multi_tensor_applier(
grad_norm_sq = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[non_parallel_grads],
[grads],
False,
)[0] ** 2
else:
non_parallel_grad_norm_sq = torch.zeros([1], device=self.device)
grad_norm_sq = torch.zeros([1], dtype=torch.float32, device=self.device)
return grad_norm_sq.detach().view([])
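The local contribution above is a squared 2-norm of the shard held by this rank; the global norm is recovered by summing the squares across ranks and taking a square root. A quick single-process sanity check of that identity:

import torch

full = torch.randn(10)
shards = torch.chunk(full, 2)                             # stand-ins for per-rank shards
local_sq = torch.stack([s.pow(2).sum() for s in shards])  # per-shard squared norms
assert torch.allclose(local_sq.sum().sqrt(), full.norm(p=2.0))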
def grad_norm(self, parameters=[], norm_type=2.0, force=False):
"""Gradient norm of parameters in optimizer
The norm is computed over all gradients together, as if they
were concatenated into a single vector. All provided
parameters must be managed by optimizer.
The computed value is cached to avoid redundant communication.
Arguments:
parameters (iterable, optional): an iterable of parameters
in optimizer (default: all parameters in optimizer).
norm_type (float or int, optional): type of the used
p-norm (default: 2). Only 2-norm is currently
supported.
force (bool, optional): ignore cached value and force norm
computation (default: False).
"""
if force or self._grad_norm is None:
norm_type = float(norm_type)
assert norm_type == 2.0
grad_norm_sq = self._local_grad_norm(
parameters=parameters,
norm_type=norm_type,
)
torch.distributed.all_reduce(
non_parallel_grad_norm_sq,
grad_norm_sq,
op=torch.distributed.ReduceOp.SUM,
group=self.distributed_process_group,
)
grad_norm_sq -= non_parallel_grad_norm_sq
self._grad_norm = grad_norm_sq.sqrt()
return self._grad_norm.detach()
return grad_norm_sq.sqrt()
def clip_grad_norm(self, max_norm, parameters=[], norm_type=2.0):
"""Clips gradient norm of parameters in optimizer
def step(self, closure=None, scale=1.):
The norm is computed over all gradients together, as if they
were concatenated into a single vector. The scaling is
deferred until the optimizer step, which should be called
immediately after this function.
The computed grad norm is cached to avoid redundant
communication.
Arguments:
max_norm (float or int): max norm of the gradients
parameters (iterable, optional): an iterable of parameters
in optimizer (default: all parameters in optimizer).
norm_type (float or int, optional): type of the used
p-norm (default: 2)
"""
assert max_norm > 0
total_norm = self.grad_norm(parameters=parameters, norm_type=norm_type)
inv_clip_coef = (total_norm + 1e-6) / max_norm
self._inv_grad_scale = torch.clamp(inv_clip_coef, min=1.0).view(1)
return total_norm
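Note that clip_grad_norm does not rewrite param.grad; it records an inverse scale that step() later applies when dividing gradients. A numeric illustration with made-up values:

import torch

total_norm = torch.tensor(7.0)                    # pretend gradient norm
max_norm = 3.5
inv_clip_coef = (total_norm + 1e-6) / max_norm
inv_grad_scale = torch.clamp(inv_clip_coef, min=1.0).view(1)
print(inv_grad_scale)                             # ~2.0: step() divides grads by 2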
def step(self, closure=None, *, grad_scaler=None):
"""Apply Adam optimizer step
Arguments:
closure (callable, optional): closure to recompute loss
(default: None)
scale (float, optional): scaling factor to divide
gradients (default: 1.0)
grad_scaler (torch.cuda.amp.GradScaler, optional):
gradient scaler (default: None)
"""
self.state['step'] += 1
# Apply closure
loss = None
if closure is not None:
loss = closure()
@@ -729,14 +780,26 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# Make sure that gradients have been reduced
self.grad_sync()
# Scale gradient if L2 norm is too large
if self.max_grad_norm > 0:
grad_norm = self.grad_norm().item()
if (math.isfinite(grad_norm)
and grad_norm / scale > self.max_grad_norm):
scale = grad_norm / self.max_grad_norm
# Apply gradient scaler if provided
# Note: We compute gradient norm to check for non-finite
# values. This is more conservative and compute intensive than
# directly checking, but it avoids extra communication if we
# have already computed gradient norm e.g. for gradient
# clipping.
if grad_scaler is not None:
grad_norm = self.grad_norm()
found_inf = torch.logical_not(torch.isfinite(grad_norm))
scaler_state = grad_scaler._per_optimizer_states[id(self)]
scaler_state['found_inf_per_device'] = {found_inf.device: found_inf.float()}
if found_inf.item():
return
else:
assert grad_scaler._scale is not None
self._inv_grad_scale *= grad_scaler._scale
inv_grad_scale = self._inv_grad_scale.item()
# Apply optimizer step to each bucket and synchronize params
self.state['step'] += 1
current_stream = torch.cuda.current_stream()
for stream in self._pipeline_streams:
stream.wait_stream(current_stream)
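A rough sketch of the scaler path above. Because the optimizer sets _step_supports_amp_scaling, torch.cuda.amp.GradScaler.step() is expected to hand itself to the optimizer via the grad_scaler keyword instead of unscaling gradients itself (an assumption about the contemporaneous PyTorch behavior); the optimizer then folds the loss scale into its deferred gradient divisor so gradients are unscaled during the step, or it reports found_inf and skips the update. The model, input, and import path below are placeholders, and torch.distributed is assumed to be initialized.

import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam  # path assumed

model = torch.nn.Linear(16, 16).cuda()                        # placeholder model
optimizer = DistributedFusedAdam(model.parameters(), lr=1e-3)
x = torch.randn(8, 16, device='cuda')

scaler = torch.cuda.amp.GradScaler(init_scale=2.0**16)
loss = model(x).sum()
scaler.scale(loss).backward()         # gradients carry the loss scale
scaler.step(optimizer)                # internally: optimizer.step(grad_scaler=scaler)
scaler.update()                       # grows or backs off the scale based on found_inf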
@@ -806,14 +869,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
eps,
weight_decay,
group['lr'],
scale,
inv_grad_scale,
self.state['step'],
1, # Set to 0 to apply eps inside sqrt
)
del group_buffers
# Deallocate buffers
del buffers
bucket['grads_shard'] = None
# Allgather updated parameters
if self.distributed_size == 1:
@@ -842,24 +905,33 @@ class DistributedFusedAdam(torch.optim.Optimizer):
del params_shard_copy
# Copy values to param buffers
params_in = []
params_out = []
buffers = collections.defaultdict(list) # param_in, param_out
for fragment in bucket['fragments']:
param_group_id = fragment['param_group_id']
param_id = fragment['param_id']
param = self.param_groups[param_group_id]['params'][param_id]
bucket_start, bucket_end = fragment['bucket_range']
param_start, param_end = fragment['param_range']
params_in.append(params_bucket[bucket_start:bucket_end])
params_out.append(param.view(-1)[param_start:param_end])
if params_in:
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
dummy_overflow_buf,
[params_in, params_out],
)
del params_bucket, params_in, params_out
buffers[(param.is_cuda, param.dtype)].append((
params_bucket[bucket_start:bucket_end],
param.detach().view(-1)[param_start:param_end],
))
for (is_cuda, dtype), dtype_buffers in buffers.items():
fused_kernel_dtypes = (torch.float32, torch.float16, torch.uint8)
if (is_cuda
and dtype in fused_kernel_dtypes
and self.param_sync_dtype in fused_kernel_dtypes):
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
dummy_overflow_buf,
list(zip(*dtype_buffers)),
)
else:
for param_in, param_out in dtype_buffers:
param_out.copy_(param_in)
del dtype_buffers
del params_bucket, buffers
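A standalone sketch of the copy-back pattern above: destination parameter fragments are grouped by device and dtype, a fused multi-tensor kernel handles the common cases, and plain copy_ (which casts as needed) is the fallback. The fused branch below uses amp_C.multi_tensor_scale, already imported in this file, purely to illustrate the multi_tensor_applier calling convention; whether it accepts this exact dtype mix is an assumption.

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

src = [torch.randn(8, device='cuda') for _ in range(3)]                       # bucket slices
dst = [torch.empty(8, device='cuda', dtype=torch.float16) for _ in range(3)]  # param slices

if all(d.is_cuda and d.dtype in (torch.float32, torch.float16) for d in dst):
    overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
    multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [src, dst], 1.0)
else:
    for s, d in zip(src, dst):
        d.copy_(s)                                                            # casts as needed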
# Synchronize pipeline streams
for stream in self._pipeline_streams:
@@ -21,11 +21,17 @@ class SimpleModel(torch.nn.Module):
y += (i+1) * l(x)
return y
def make_models(num_layers, size, dtype=torch.float32, overlap_communication=True):
def make_models(
num_layers,
size,
dtype=torch.float32,
device='cuda',
overlap_communication=True,
):
# Construct models with same parameters
ref_model = SimpleModel(num_layers, size).to(dtype=dtype, device='cuda')
dist_model = SimpleModel(num_layers, size).to(dtype=dtype, device='cuda')
ref_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device)
dist_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device)
with torch.no_grad():
for ref_param, dist_param in zip(dist_model.parameters(),
ref_model.parameters()):
@@ -35,22 +41,22 @@ def make_models(num_layers, size, dtype=torch.float32, overlap_communication=True):
rank = torch.distributed.get_rank()
ref_model = torch.nn.parallel.DistributedDataParallel(
ref_model,
device_ids=[rank],
output_device=rank,
device_ids=[rank] if device=='cuda' else None,
output_device=rank if device=='cuda' else None,
)
# Construct optimizers with same hyperparameters
optim_args = { 'lr': 1, 'betas': (0.1,0.2), 'eps': 0.1, 'weight_decay': 0.1 }
optim_args = dict(lr=0.1, betas=(0.1,0.2), eps=0.25, weight_decay=0.1)
ref_optim = torch.optim.AdamW(
[
{'params': list(ref_model.parameters())[1::2], 'lr': 0.5},
{'params': list(ref_model.parameters())[1::2], 'lr': 0.2},
{'params': list(ref_model.parameters())[0::2]},
],
**optim_args,
)
dist_optim = DistributedFusedAdam(
[
{'params': list(dist_model.parameters())[1::2], 'lr': 0.5},
{'params': list(dist_model.parameters())[1::2], 'lr': 0.2},
{'params': list(dist_model.parameters())[0::2]},
],
overlap_grad_sync=overlap_communication,
@@ -81,8 +87,9 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
overlap_communication=True,
use_nosync=True,
dtype=torch.float32,
rtol=1e-5,
atol=1e-5,
device='cuda',
rtol=None,
atol=None,
):
torch.manual_seed(self.seed + self.rank)
@@ -92,6 +99,7 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
num_layers,
layer_size,
dtype=dtype,
device=device,
overlap_communication=overlap_communication,
)
@@ -106,10 +114,10 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
for micro_step in range(micro_batch_steps):
# Synthetic data
x = torch.rand(batch_size, layer_size) + 0.5
dy = torch.rand_like(x) + 0.5
x = x.to(dtype=dtype, device='cuda')
dy = dy.to(dtype=dtype, device='cuda')
x = torch.rand(batch_size, layer_size) - 0.5
dy = torch.rand_like(x) - 0.5
x = x.to(dtype=dtype, device=device)
dy = dy.to(dtype=dtype, device=device)
# Reference implementation
x_ref = x.detach().clone().requires_grad_(True)
@@ -136,8 +144,8 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
dist_optim.step()
# Check that parameters match
for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
dist_model.parameters())):
for ref_param, dist_param in zip(ref_model.parameters(),
dist_model.parameters()):
torch.testing.assert_close(
dist_param, ref_param, rtol=rtol, atol=atol)
@@ -150,6 +158,20 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
def test_matches_pytorch_sync_every_step(self):
self.test_matches_pytorch(use_nosync=False)
def test_matches_pytorch_fp64(self):
self.test_matches_pytorch(
dtype=torch.float64,
rtol=1.3e-6,
atol=1e-5,
)
def test_matches_pytorch_fp16(self):
self.test_matches_pytorch(
dtype=torch.float16,
rtol=1e-2,
atol=1e-2,
)
def test_raises_on_mismatch(self):
torch.manual_seed(self.seed + self.rank)
@@ -172,16 +194,89 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
dist_optim.step()
# Check that parameters do not match
for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
dist_model.parameters())):
for ref_param, dist_param in zip(ref_model.parameters(),
dist_model.parameters()):
self.assertRaises(
AssertionError,
torch.testing.assert_close,
dist_param, ref_param,
rtol=1e-5,
atol=1e-5,
)
def test_clip_grad_norm(self):
torch.manual_seed(self.seed + self.rank)
# Identical models with data-parallel and ZeRO
ref_model, ref_optim, dist_model, dist_optim = make_models(1, 1)
# Training steps with pre-determined gradients
xs = [3, 1, 4, 1, 5, 9]
dys = [1, -1, 1, -1, 1, -1]
for x, dy in zip(xs, dys):
x = torch.tensor([x], dtype=torch.float32, device='cuda')
dy = torch.tensor([dy], dtype=torch.float32, device='cuda')
# Reference implementation
ref_optim.zero_grad()
y_ref = ref_model(x.detach())
y_ref.backward(dy.detach())
ref_grad_norm = torch.nn.utils.clip_grad_norm_(ref_model.parameters(), 3.5)
ref_optim.step()
# Distributed implementation
dist_optim.zero_grad()
y_dist = dist_model(x.detach())
y_dist.backward(dy.detach())
dist_grad_norm = dist_optim.clip_grad_norm(3.5)
dist_optim.step()
# Check that parameters match
torch.testing.assert_close(dist_grad_norm, ref_grad_norm)
for ref_param, dist_param in zip(ref_model.parameters(),
dist_model.parameters()):
torch.testing.assert_close(dist_param, ref_param)
def test_grad_scaler(self):
torch.manual_seed(self.seed + self.rank)
# Identical models with data-parallel and ZeRO
ref_model, ref_optim, dist_model, dist_optim = make_models(1, 1)
grad_scaler_args = dict(
init_scale=3.21,
growth_factor=1.23,
backoff_factor=0.876,
growth_interval=1,
)
ref_scaler = torch.cuda.amp.GradScaler(**grad_scaler_args)
dist_scaler = torch.cuda.amp.GradScaler(**grad_scaler_args)
# Training steps with pre-determined gradients
xs = [3, 1, 4, 1, 5, 9]
dys = [1, float('inf'), 1, 1, float('nan'), -1]
for x, dy in zip(xs, dys):
x = torch.tensor([x], dtype=torch.float32, device='cuda')
dy = torch.tensor([dy], dtype=torch.float32, device='cuda')
# Reference implementation
ref_optim.zero_grad()
y_ref = ref_model(x.detach())
ref_scaler.scale(y_ref).backward(dy.detach())
ref_scaler.step(ref_optim)
ref_scaler.update()
# Distributed implementation
dist_optim.zero_grad()
y_dist = dist_model(x.detach())
dist_scaler.scale(y_dist).backward(dy.detach())
dist_scaler.step(dist_optim)
dist_scaler.update()
# Check that parameters match
for ref_param, dist_param in zip(ref_model.parameters(),
dist_model.parameters()):
torch.testing.assert_close(dist_param, ref_param)
if __name__ == "__main__":
# Assume script has been run with torchrun
common_utils.run_tests()