Unverified Commit 2e025ab5 authored by Tim Moon, committed by GitHub

Improvements in distributed Adam optimizer for Megatron (#1432)

* Improvements in distributed Adam optimizer for Megatron

Add option to allocate gradient buckets out of one large buffer. Add option to initialize params in user-provided order. Perform communication when saving optimizer state. Support param sync with any dtype.

* Style fixes in distributed Adam helper classes

Review suggestions from @crcrpar
parent fb21698e
@@ -3,6 +3,7 @@ import contextlib
import enum
import importlib
import inspect
import io
import math
import threading
@@ -11,6 +12,10 @@ import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
from torch.distributed.distributed_c10d import _get_default_group, _get_global_rank
def _round_to_multiple(number, multiple, round_up=True):
"""Assumes arguments are positive integers"""
return (number+multiple-1 if round_up else number) // multiple * multiple
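For instance, with positive-integer arguments the helper behaves as follows (illustrative values, not part of the diff):

assert _round_to_multiple(100, 32) == 128                  # rounds up by default
assert _round_to_multiple(100, 32, round_up=False) == 96   # rounds down when requested
assert _round_to_multiple(96, 32) == 96                    # exact multiples are unchanged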
class DistributedFusedAdam(torch.optim.Optimizer):
"""AdamW optimizer with ZeRO algorithm.
@@ -49,8 +54,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
grad_sync_dtype (torch.dtype, optional): datatype for gradient
synchronization (default: same as dtype)
param_sync_dtype (torch.dtype, optional): datatype for
parameter synchronization (default: same as
grad_sync_dtype)
parameter synchronization (default: same as dtype)
device (torch.device, optional): device for optimizer state
(default: cuda). Currently only supports GPU with one GPU
per process.
@@ -75,6 +79,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
(default: 100)
pipeline_size (int, optional): number of buckets to
synchronize simultaneously (default: 2)
contiguous_grad_buffer (bool, optional): allocate gradient
buckets out of a large persistent buffer (default: False).
This allows individual parameter gradients to be accessed
externally (see grad_buffer_view function). It also
increases memory usage and may prevent overlapping
communication and compute.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
@@ -86,6 +96,56 @@ class DistributedFusedAdam(torch.optim.Optimizer):
"""
class ParameterFragment:
"""Buffer ranges for a parameter fragment
Describes corresponding regions in parameter buffer and
parameter bucket.
"""
def __init__(
self,
param_group_id,
param_id,
bucket_id,
param_range,
bucket_range,
in_local_shard,
shard_range,
shard_bucket_range,
shard_param_range,
):
# Parameter group index
self.param_group_id = param_group_id
# Parameter index within parameter group
self.param_id = param_id
# Bucket index
self.bucket_id = bucket_id
# Range within flattened parameter buffer
self.param_range = param_range
# Range within bucket
self.bucket_range = bucket_range
# Whether fragment is in local shard of bucket
self.in_local_shard = in_local_shard
# Range within local shard
self.shard_range = shard_range
# Range of local fragment shard within bucket
self.shard_bucket_range = shard_bucket_range
# Range of local fragment shard within parameter
self.shard_param_range = shard_param_range
class StateBucket:
"""Optimizer state for a bucket"""
def __init__(self, shard_size, dtype, device):
# Buffer ranges corresponding to parameter fragments
self.fragments = []
# Local shard of parameters
self.params_shard = torch.zeros([shard_size], dtype=dtype, device=device)
# Local shard of first moment estimate
self.exp_avg_shard = torch.zeros([shard_size], dtype=dtype, device=device)
# Local shard of second moment estimate
self.exp_avg_sq_shard = torch.zeros([shard_size], dtype=dtype, device=device)
class GradientStatus(enum.Enum):
"""Status of gradients within a bucket"""
# Gradients are ready to use
@@ -97,6 +157,26 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# Asynchronous reduction is in progress
SYNCING = enum.auto()
class GradientBucket:
"""Gradient buffers and state for a bucket"""
def __init__(self):
# Local shard of gradients
self.grads_shard = None
# Local contribution to gradients
self.grads_bucket = None
# Buffer for gradient reduce-scatter
self.sync_grads_shard = None
# Status of gradients
self.status = DistributedFusedAdam.GradientStatus.READY
# Request object for asynchronous communication
self.sync_request = None
def sync_wait(self):
"""Wait for asynchronous communication to finish"""
if self.sync_request is not None:
self.sync_request.wait()
self.sync_request = None
_step_supports_amp_scaling = True
def __init__(self,
@@ -118,6 +198,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
overlap_grad_sync=True,
bucket_cap_mb=100,
pipeline_size=2,
contiguous_grad_buffer=False,
):
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay)
@@ -131,12 +212,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
if grad_sync_dtype is None:
grad_sync_dtype = dtype
if param_sync_dtype is None:
param_sync_dtype = grad_sync_dtype
param_sync_dtype = dtype
supported_dtypes = [
(torch.float32, torch.float16, torch.float16),
(torch.float32, torch.float32, torch.float32),
(torch.float32, torch.float16),
(torch.float32, torch.float32),
]
if (dtype, grad_sync_dtype, param_sync_dtype) not in supported_dtypes:
if (dtype, grad_sync_dtype) not in supported_dtypes:
raise RuntimeError(
'Invalid dtypes for DistributedFusedAdam '
f'(dtype={dtype}, '
@@ -176,6 +257,13 @@ class DistributedFusedAdam(torch.optim.Optimizer):
f'distributed process group size = {self.distributed_size}, '
f'redundant process group size = {self.redundant_size})'
)
try:
self._process_group_ranks = [
_get_global_rank(self.process_group, local_rank)
for local_rank in range(self.distributed_size)
]
except:
self._process_group_ranks = list(range(self.distributed_size))
# Use average reduction for grad sync
self.average_grad_sync = average_grad_sync
@@ -185,13 +273,15 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self.overlap_grad_sync = overlap_grad_sync
# Number of buckets to synchronize at a time
self.pipeline_size = pipeline_size
# Allocate contiguous buffer for gradients
self.contiguous_grad_buffer = contiguous_grad_buffer
# Determine bucket sizes
dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8
self.alignment = 128 // dtype_size
bucket_size = 1024*1024*bucket_cap_mb / dtype_size
shard_size = bucket_size / self.distributed_size
shard_size = (int(shard_size) // self.alignment) * self.alignment
shard_size = int(bucket_size / self.distributed_size)
shard_size = _round_to_multiple(shard_size, self.alignment, round_up=False)
shard_size = max(shard_size, self.alignment)
bucket_size = shard_size * self.distributed_size
self.bucket_size = bucket_size
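A worked example of this sizing arithmetic (illustrative numbers, assuming fp16 gradient synchronization, bucket_cap_mb=100, and 8 distributed ranks):

dtype_size = torch.finfo(torch.float16).bits // 8         # 2 bytes per element
alignment = 128 // dtype_size                              # 64 elements
bucket_size = 1024 * 1024 * 100 / dtype_size               # 52,428,800 elements
shard_size = int(bucket_size / 8)                          # 6,553,600 elements per rank
shard_size = _round_to_multiple(shard_size, alignment, round_up=False)  # already a multiple of 64
bucket_size = shard_size * 8                               # 52,428,800 elements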
@@ -207,6 +297,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self.state['step'] = 0
# Objects for gradient synchronization
self._grads_buckets = collections.defaultdict(self.GradientBucket)
self._grads_generated = set()
self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)]
@@ -224,6 +315,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._all_gather_no_copy = (
'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args
)
self._gather_no_copy = (
'no_copy' in inspect.getfullargspec(torch.distributed.gather).args
)
# Attach hooks for gradient synchronization
self._register_post_backward_hooks()
@@ -236,20 +330,20 @@ class DistributedFusedAdam(torch.optim.Optimizer):
"""
self._num_grads = 0
grad_buffer_size = 0
self._lock = threading.Lock()
self._grad_accs = []
try:
root_rank = _get_global_rank(self.process_group, 0)
except:
root_rank = 0
for param_group_id, group in enumerate(self.param_groups):
for param_id, param in enumerate(group['params']):
torch.distributed.broadcast(
param,
src=root_rank,
src=self._process_group_ranks[0],
group=self.process_group,
)
if param.requires_grad:
self._num_grads += 1
# Callback after gradient is generated
def wrapper(p, p_group_id, p_id):
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]
@@ -261,13 +355,57 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._grad_copy(p)
if self.overlap_grad_sync:
self._try_start_bucket_grad_sync(
[p],
params=[p],
ignore_last_bucket=True,
)
grad_acc.register_hook(reduction_hook)
self._grad_accs.append(grad_acc)
wrapper(param, param_group_id, param_id)
self._num_grads += 1
# Gradient size, with padding for alignment
grad_size = _round_to_multiple(param.numel(), self.alignment)
grad_buffer_size += grad_size
# Allocate contiguous gradient buffer if needed
if self.contiguous_grad_buffer:
grad_buffer_size = _round_to_multiple(
grad_buffer_size,
self.bucket_size,
)
self._grad_buffer = torch.zeros(
[grad_buffer_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
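The per-parameter reduction hooks registered above rely on PyTorch's gradient-accumulator nodes. A standalone sketch of that pattern, independent of this optimizer (hypothetical tensor and hook body):

import torch

param = torch.nn.Parameter(torch.randn(4))
grad_accs = []  # keep references alive, as the optimizer does with self._grad_accs

def hook(*unused):
    # Fires once the gradient for `param` has been accumulated
    print('gradient ready, shape', tuple(param.shape))

# Expanding the parameter yields a non-leaf tensor whose grad_fn points
# back at the parameter's AccumulateGrad node.
p_tmp = param.expand_as(param)
grad_acc = p_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(hook)
grad_accs.append(grad_acc)

(2 * param).sum().backward()  # triggers the hook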
def init_params(self, params=None):
"""Initialize optimizer state for parameters
Arguments:
params (iterable, optional): parameters to initialize
(default: all parameters)
"""
# Default cases
if isinstance(params, torch.Tensor):
params = [params]
elif params is None:
params = []
for group in self.param_groups:
params.extend(group['params'])
# Get indices corresponding to parameters
id_map = dict()
for param_group_id, group in enumerate(self.param_groups):
for param_id, param in enumerate(group['params']):
id_map[param] = (param_group_id, param_id)
# Initialize parameters
for param in params:
if param in id_map and 'fragments' not in self.state[param]:
param_group_id, param_id = id_map[param]
self._init_param_state(param, param_group_id, param_id)
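A hedged usage sketch of this entry point (hypothetical model and attribute names): passing an explicit iterable controls the order in which parameters are packed into buckets, rather than relying on the order in which their state would otherwise be initialized.

# Pack parameters into buckets following the model's own ordering.
optimizer.init_params(list(model.parameters()))

# Or initialize a chosen subset first; the rest can be initialized later.
optimizer.init_params([model.embedding.weight])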
def _init_param_state(
self,
@@ -279,7 +417,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# Make sure there is at least one bucket
if not self.state['buckets']:
self._add_bucket()
self.state['buckets'].append(
self.StateBucket(self.shard_size, self.dtype, self.device)
)
# Split parameter values into fragments
# Note: Each fragment resides within a bucket
@@ -289,29 +429,25 @@ class DistributedFusedAdam(torch.optim.Optimizer):
while param_start < param_size:
# Get current bucket
if not self.state['buckets']:
self._add_bucket()
bucket_id = len(self.state['buckets']) - 1
bucket = self.state['buckets'][bucket_id]
fragment_id = len(bucket['fragments'])
fragment_id = len(bucket.fragments)
# Determine fragment position within bucket
if fragment_id == 0:
bucket_start = 0
else:
bucket_start = bucket['fragments'][-1]['bucket_range'][1]
bucket_start = (
(bucket_start + self.alignment - 1)
// self.alignment
* self.alignment
) # Pad until fragment is aligned
_, bucket_start = bucket.fragments[-1].bucket_range
bucket_start = _round_to_multiple(bucket_start, self.alignment)
fragment_size = min(param_size-param_start, self.bucket_size-bucket_start)
param_end = param_start + fragment_size
bucket_end = bucket_start + fragment_size
# Create new bucket if current one is full
if fragment_size <= 0:
self._add_bucket()
self.state['buckets'].append(
self.StateBucket(self.shard_size, self.dtype, self.device)
)
continue
# Fragment position within local shard
@@ -331,81 +467,55 @@ class DistributedFusedAdam(torch.optim.Optimizer):
shard_param_start, shard_param_end = None, None
# Record fragment info
fragment = {
# Parameter group index
'param_group_id': param_group_id,
# Parameter index within parameter group
'param_id': param_id,
# Bucket index
'bucket_id': bucket_id,
# Range within flattened parameter buffer
'param_range': (param_start,param_end),
# Range within bucket
'bucket_range': (bucket_start,bucket_end),
# Whether fragment is in local shard of bucket
'in_local_shard': in_local_shard,
# Range within local shard
'shard_range': (shard_start,shard_end),
# Range of local fragment shard within bucket
'shard_bucket_range': (shard_bucket_start,shard_bucket_end),
# Range of local fragment shard within parameter
'shard_param_range': (shard_param_start,shard_param_end),
}
# Record fragment info
fragment = self.ParameterFragment(
param_group_id=param_group_id,
param_id=param_id,
bucket_id=bucket_id,
param_range=(param_start,param_end),
bucket_range=(bucket_start,bucket_end),
in_local_shard=in_local_shard,
shard_range=(shard_start,shard_end),
shard_bucket_range=(shard_bucket_start,shard_bucket_end),
shard_param_range=(shard_param_start,shard_param_end),
)
self.state[param]['fragments'].append(fragment)
bucket['fragments'].append(fragment)
bucket.fragments.append(fragment)
param_start = param_end
# Initialize master param buffer
for fragment in self.state[param]['fragments']:
if fragment['in_local_shard']:
bucket_id = fragment['bucket_id']
bucket = self.state['buckets'][bucket_id]
param_start, param_end = fragment['shard_param_range']
shard_start, shard_end = fragment['shard_range']
if fragment.in_local_shard:
bucket = self.state['buckets'][fragment.bucket_id]
param_start, param_end = fragment.shard_param_range
shard_start, shard_end = fragment.shard_range
model_param_fragment = param.view(-1)[param_start:param_end]
master_param_fragment = bucket['params_shard'][shard_start:shard_end]
master_param_fragment = bucket.params_shard[shard_start:shard_end]
master_param_fragment.copy_(model_param_fragment)
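A small worked example of the fragment bookkeeping above (illustrative sizes, not taken from the diff): with bucket_size=8, shard_size=4, distributed_size=2, and alignment=1, a 10-element parameter is split across two buckets.

# fragment 0: param_range=(0, 8),  bucket_id=0, bucket_range=(0, 8)
# fragment 1: param_range=(8, 10), bucket_id=1, bucket_range=(0, 2)
#
# Rank 0 owns bucket elements [0, 4) of each bucket:
#   fragment 0: in_local_shard=True, shard_range=(0, 4),
#               shard_bucket_range=(0, 4), shard_param_range=(0, 4)
#   fragment 1: in_local_shard=True, shard_range=(0, 2),
#               shard_bucket_range=(0, 2), shard_param_range=(8, 10)
# Rank 1 owns bucket elements [4, 8):
#   fragment 0: in_local_shard=True, shard_range=(0, 4),
#               shard_bucket_range=(4, 8), shard_param_range=(4, 8)
#   fragment 1: in_local_shard=False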
def _add_bucket(self):
"""Construct a bucket for optimizer state"""
self.state['buckets'].append({
# Parameter fragments associated with bucket
'fragments': [],
# Gradient buffers
'grads_shard': None,
'grads_bucket': None,
'curr_grads_shard': None, # For current micro-batch
# Optimizer state
'params_shard': torch.zeros([self.shard_size], dtype=self.dtype, device=self.device),
'exp_avg_shard': torch.zeros([self.shard_size], dtype=self.dtype, device=self.device),
'exp_avg_sq_shard': torch.zeros([self.shard_size], dtype=self.dtype, device=self.device),
# Status of parameter gradients
'gradient_status': self.GradientStatus.READY,
# Distributed request object for gradient synchronization
'grad_sync_request': None,
})
def zero_grad(self, set_to_none=True):
"""Clear parameter gradients"""
# Reset bucket buffers
self._grads_buckets.clear()
# Construct views into contiguous grad buffer, if needed
if self.contiguous_grad_buffer:
self._grad_buffer.zero_()
for bucket_id in range(len(self.state['buckets'])):
bucket_start = bucket_id * self.bucket_size
bucket_end = bucket_start + self.bucket_size
bucket = self._grads_buckets[bucket_id]
bucket.grads_bucket = self._grad_buffer[bucket_start:bucket_end]
# Reset param grads
for group in self.param_groups:
for param in group['params']:
if param.grad is None or set_to_none:
param.grad = None
else:
param.grad.zero_()
for bucket in self.state['buckets']:
bucket['grads_shard'] = None
bucket['grads_bucket'] = None
bucket['curr_grads_shard'] = None
bucket['gradient_status'] = self.GradientStatus.READY
# Reset other state
self._grads_generated = set()
self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device)
self._grad_norm = None
@@ -417,51 +527,80 @@ class DistributedFusedAdam(torch.optim.Optimizer):
for fragment in self.state[param]['fragments']:
# Get fragment position
bucket_id = fragment['bucket_id']
bucket = self.state['buckets'][bucket_id]
grad_start, grad_end = fragment['param_range']
bucket_start, bucket_end = fragment['bucket_range']
bucket_id = fragment.bucket_id
bucket = self._grads_buckets[bucket_id]
grad_start, grad_end = fragment.param_range
bucket_start, bucket_end = fragment.bucket_range
# Set reduction status
if bucket['gradient_status'] == self.GradientStatus.SYNCING:
if bucket.status == self.GradientStatus.SYNCING:
self._finish_bucket_grad_sync()
bucket['gradient_status'] = self.GradientStatus.PARTIALLY_FILLED
bucket.status = self.GradientStatus.PARTIALLY_FILLED
# Allocate gradient buffer if needed
if bucket['grads_bucket'] is None:
bucket['grads_bucket'] = torch.zeros(
[self.bucket_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
if bucket.grads_bucket is None:
if self.contiguous_grad_buffer:
grad_buffer_start = bucket_id * self.bucket_size
grad_buffer_end = grad_buffer_start + self.bucket_size
bucket.grads_bucket = self._grad_buffer[grad_buffer_start:grad_buffer_end]
else:
bucket.grads_bucket = torch.empty(
[self.bucket_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
bucket.grads_bucket.zero_()
# Copy param grad to bucket
if param.grad is not None:
scale = 1/self.process_group_size if self.average_grad_sync else 1.0
grad_in = param.grad.detach().view(-1)[grad_start:grad_end]
grad_out = bucket['grads_bucket'][bucket_start:bucket_end]
grad_out.add_(grad_in, alpha=scale)
grad_out = bucket.grads_bucket[bucket_start:bucket_end]
if grad_in.data_ptr() != grad_out.data_ptr():
grad_out.add_(grad_in)
# Free param grad buffer
param.grad = None
def grad_buffer_view(self, param):
"""Construct view into grad buffer corresponding to param
Assumes optimizer is using a contiguous grad buffer.
"""
assert self.contiguous_grad_buffer
# Figure out corresponding position in grad buffer
param_fragments = self.state[param]['fragments']
start_bucket_id = param_fragments[0].bucket_id
start_bucket_offset, _ = param_fragments[0].bucket_range
end_bucket_id = param_fragments[-1].bucket_id
_, end_bucket_offset = param_fragments[-1].bucket_range
buffer_start = start_bucket_id * self.bucket_size + start_bucket_offset
buffer_end = end_bucket_id * self.bucket_size + end_bucket_offset
# Construct view into grad buffer
flat_buffer = self._grad_buffer[buffer_start:buffer_end]
return flat_buffer.detach().view(param.size())
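A hedged usage sketch (assumes contiguous_grad_buffer=True and that the parameter's state has been initialized, e.g. via init_params); the main_grad attribute is a Megatron-style convention, not something this class defines:

optimizer.init_params(list(model.parameters()))
for param in model.parameters():
    # View into the persistent grad buffer, shaped like the parameter,
    # so external code can accumulate gradients into it directly.
    param.main_grad = optimizer.grad_buffer_view(param)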
def _force_bucket_grad_sync(self):
"""Ensure that all gradient buckets are synchronized"""
# Synchronize all unsynchronized buckets
self._finish_bucket_grad_sync()
buckets = [
bucket for bucket in self.state['buckets']
if bucket['gradient_status'] != self.GradientStatus.READY
bucket
for bucket_id, bucket in sorted(self._grads_buckets.items())
if bucket.status != self.GradientStatus.READY
]
if buckets:
self._start_bucket_grad_sync(buckets)
self._finish_bucket_grad_sync()
# Fill any unfilled buckets with zeros
for bucket in self.state['buckets']:
if bucket['grads_shard'] is None:
bucket['grads_shard'] = torch.zeros(
# Fill any unsynchronized gradients with zeros
for bucket_id in range(len(self.state['buckets'])):
bucket = self._grads_buckets[bucket_id]
if bucket.grads_shard is None:
bucket.grads_shard = torch.zeros(
[self.shard_size],
dtype=self.grad_sync_dtype,
device=self.device,
@@ -495,75 +634,91 @@ class DistributedFusedAdam(torch.optim.Optimizer):
for param in params:
self._grads_generated.add(param)
for fragment in self.state[param]['fragments']:
bucket_id = fragment['bucket_id']
bucket = self.state['buckets'][bucket_id]
bucket_id = fragment.bucket_id
bucket_fragments = self.state['buckets'][bucket_id].fragments
is_filled = True
for other_fragment in reversed(bucket['fragments']):
param_group_id = other_fragment['param_group_id']
param_id = other_fragment['param_id']
for other_fragment in reversed(bucket_fragments):
param_group_id = other_fragment.param_group_id
param_id = other_fragment.param_id
other_param = self.param_groups[param_group_id]['params'][param_id]
if other_param not in self._grads_generated:
is_filled = False
break
if is_filled:
bucket['gradient_status'] = self.GradientStatus.FULLY_FILLED
bucket = self._grads_buckets[bucket_id]
bucket.status = self.GradientStatus.FULLY_FILLED
# Launch reductions if enough buckets are ready
if len(self._grads_generated) == self._num_grads:
self._force_bucket_grad_sync()
else:
all_buckets = self.state['buckets']
if ignore_last_bucket:
all_buckets = all_buckets[:-1]
filled_buckets = [
bucket
for bucket in all_buckets
if bucket['gradient_status'] == self.GradientStatus.FULLY_FILLED
]
pipeline_size = (len(filled_buckets) // self.pipeline_size) * self.pipeline_size
filled_buckets = []
for bucket_id, bucket in sorted(self._grads_buckets.items()):
if ignore_last_bucket and bucket_id == len(self.state['buckets'])-1:
continue
if bucket.status == self.GradientStatus.FULLY_FILLED:
filled_buckets.append(bucket)
pipeline_size = _round_to_multiple(
len(filled_buckets),
self.pipeline_size,
)
if pipeline_size > 0:
self._start_bucket_grad_sync(filled_buckets[:pipeline_size])
def _start_bucket_grad_sync(self, buckets):
"""Synchronize gradients in buckets
"""Synchronize gradient buckets
Gradient synchronization is asynchronous. It involves a
reduce-scatter over the distributed process group and an all-reduce
over the redundant process group.
"""
# Call recursively if more buckets than streams
while len(buckets) > self.pipeline_size:
self._start_bucket_grad_sync(buckets[:self.pipeline_size])
buckets = buckets[self.pipeline_size:]
self._finish_bucket_grad_sync()
# Reduction operation
if self.average_grad_sync:
reduce_op = torch.distributed.ReduceOp.AVG
else:
reduce_op = torch.distributed.ReduceOp.SUM
# Reduce gradients
main_stream = torch.cuda.current_stream()
for stream in self._pipeline_streams:
stream.wait_stream(torch.cuda.current_stream())
stream.wait_stream(main_stream)
for i, bucket in enumerate(buckets):
bucket['gradient_status'] = self.GradientStatus.SYNCING
bucket.status = self.GradientStatus.SYNCING
stream = self._pipeline_streams[i % self.pipeline_size]
with torch.cuda.stream(stream):
# Reduce-scatter over distributed process group
bucket.sync_wait()
if self.distributed_size == 1:
bucket['curr_grads_shard'] = bucket['grads_bucket']
bucket['grad_sync_request'] = None
bucket.sync_grads_shard = bucket.grads_bucket
else:
bucket['curr_grads_shard'] = torch.zeros(
[self.shard_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
with torch.cuda.stream(main_stream):
bucket.sync_grads_shard = torch.zeros(
[self.shard_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
grads_bucket_shards = [
bucket['grads_bucket'][i*self.shard_size:(i+1)*self.shard_size]
bucket.grads_bucket[i*self.shard_size:(i+1)*self.shard_size]
for i in range(self.distributed_size)
]
if self._reduce_scatter_no_copy:
no_copy_kwarg = { 'no_copy': True }
else:
no_copy_kwarg = {}
bucket['grad_sync_request'] = (
bucket.sync_request = (
torch.distributed.reduce_scatter(
bucket['curr_grads_shard'],
bucket.sync_grads_shard,
grads_bucket_shards,
op=reduce_op,
group=self.distributed_process_group,
async_op=True,
**no_copy_kwarg,
@@ -576,11 +731,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
# submitted in a consistent order. There could be race
# conditions if wait doesn't finish in order.
if self.redundant_size > 1:
if bucket['grad_sync_request'] is not None:
bucket['grad_sync_request'].wait()
bucket['grad_sync_request'] = (
bucket.sync_wait()
bucket.sync_request = (
torch.distributed.all_reduce(
bucket['curr_grads_shard'],
bucket.sync_grads_shard,
op=reduce_op,
group=self.redundant_process_group,
async_op=True,
)
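Stripped of buckets and pipelining, this two-level communication reduces to the sketch below (illustrative tensors and process-group handles, not part of the diff): a reduce-scatter leaves each rank with one shard of the summed gradients, and an all-reduce across the redundant group reconciles replicated shards.

import torch
import torch.distributed as dist

# Assume shard_size, distributed_group, and redundant_group already exist.
world = dist.get_world_size(distributed_group)
grads_bucket = torch.randn(shard_size * world, device='cuda')
grads_shard = torch.empty(shard_size, device='cuda')

dist.reduce_scatter(
    grads_shard,
    list(grads_bucket.chunk(world)),
    op=dist.ReduceOp.SUM,
    group=distributed_group,
)
dist.all_reduce(grads_shard, op=dist.ReduceOp.SUM, group=redundant_group)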
@@ -588,26 +743,22 @@ class DistributedFusedAdam(torch.optim.Optimizer):
def _finish_bucket_grad_sync(self):
"""Wait for any gradient synchronizations that are in progress"""
for bucket in self.state['buckets']:
if bucket['gradient_status'] == self.GradientStatus.SYNCING:
for bucket_id, bucket in sorted(self._grads_buckets.items()):
if bucket.status == self.GradientStatus.SYNCING:
# Finish asynchronous communication
if bucket['grad_sync_request'] is not None:
bucket['grad_sync_request'].wait()
bucket['grad_sync_request'] = None
bucket.sync_wait()
# Accumulate gradient in local shard
if bucket['grads_shard'] is None:
bucket['grads_shard'] = bucket['curr_grads_shard']
if bucket.grads_shard is None:
bucket.grads_shard = bucket.sync_grads_shard
else:
bucket['grads_shard'].add_(bucket['curr_grads_shard'])
# Deallocate buffers for gradient synchronization
bucket['grads_bucket'] = None
bucket['curr_grads_shard'] = None
bucket.grads_shard.add_(bucket.sync_grads_shard)
bucket.grads_bucket = None
bucket.sync_grads_shard = None
# Reset status
bucket['gradient_status'] = self.GradientStatus.READY
bucket.status = self.GradientStatus.READY
# Cached gradient norm has been invalidated
self._grad_norm = None
@@ -642,14 +793,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
def grad_sync(self):
"""Ensure that all gradients are synchronized"""
for bucket in self.state['buckets']:
for fragment in bucket['fragments']:
param_group_id = fragment['param_group_id']
param_id = fragment['param_id']
for fragment in bucket.fragments:
param_group_id = fragment.param_group_id
param_id = fragment.param_id
param = self.param_groups[param_group_id]['params'][param_id]
if param.grad is not None:
self._grad_copy(param)
self._try_start_bucket_grad_sync(
[param],
params=[param],
ignore_last_bucket=False,
)
self._force_bucket_grad_sync()
@@ -676,7 +827,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
grad_norm_sq = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[[bucket['grads_shard'] for bucket in self.state['buckets']]],
[[bucket.grads_shard for bucket in self._grads_buckets.values()]],
False,
)[0] ** 2
else:
@@ -684,11 +835,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
grads = []
for param in parameters:
for fragment in self.state[param]['fragments']:
if fragment['in_local_shard']:
bucket_id = fragment['bucket_id']
bucket = self.state['buckets'][bucket_id]
shard_start, shard_end = fragment['shard_range']
grads.append(bucket['grads_shard'][shard_start:shard_end])
if fragment.in_local_shard:
bucket = self._grads_buckets[fragment.bucket_id]
shard_start, shard_end = fragment.shard_range
grads.append(bucket.grads_shard[shard_start:shard_end])
if grads:
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
grad_norm_sq = multi_tensor_applier(
@@ -798,45 +948,79 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._inv_grad_scale *= grad_scaler._scale
inv_grad_scale = self._inv_grad_scale.item()
# Construct workspace buffers
params_bucket_buffers = [
torch.empty(
[self.bucket_size],
dtype=self.param_sync_dtype,
device=self.device,
)
for _ in range(self.pipeline_size)
]
if self.grad_sync_dtype == self.param_sync_dtype:
shard_start = self.distributed_rank * self.shard_size
shard_end = shard_start + self.shard_size
params_copy_buffers = [
params_bucket[shard_start:shard_end]
for params_bucket in params_bucket_buffers
]
else:
params_copy_buffers = [
torch.empty(
[self.shard_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
for _ in range(self.pipeline_size)
]
# Apply optimizer step to each bucket and synchronize params
self.state['step'] += 1
current_stream = torch.cuda.current_stream()
main_stream = torch.cuda.current_stream()
for stream in self._pipeline_streams:
stream.wait_stream(current_stream)
for i, bucket in enumerate(self.state['buckets']):
stream = self._pipeline_streams[i % self.pipeline_size]
stream.wait_stream(main_stream)
for bucket_id in range(len(self.state['buckets'])):
stream_id = bucket_id % self.pipeline_size
# Bucket buffers
fragments = self.state['buckets'][bucket_id].fragments
shard_start = self.distributed_rank * self.shard_size
shard_end = shard_start + self.shard_size
params_bucket = params_bucket_buffers[stream_id]
params_bucket_shard = params_bucket[shard_start:shard_end]
params_shard = self.state['buckets'][bucket_id].params_shard
params_copy = params_copy_buffers[stream_id]
exp_avg = self.state['buckets'][bucket_id].exp_avg_shard
exp_avg_sq = self.state['buckets'][bucket_id].exp_avg_sq_shard
grads = self._grads_buckets[bucket_id].grads_shard
# Perform compute on parallel stream
stream = self._pipeline_streams[stream_id]
with torch.cuda.stream(stream):
# Buffer for param sync
params_shard_copy = torch.zeros(
[self.shard_size],
dtype=self.param_sync_dtype,
device=self.device,
)
# Find param fragments in local shard
buffers = collections.defaultdict(list) # p, m, v, g, p_copy
for fragment in bucket['fragments']:
if fragment['in_local_shard']:
param_group_id = fragment['param_group_id']
shard_start, shard_end = fragment['shard_range']
for fragment in fragments:
if fragment.in_local_shard:
param_group_id = fragment.param_group_id
shard_start, shard_end = fragment.shard_range
buffers[param_group_id].append([
bucket['params_shard'][shard_start:shard_end],
bucket['exp_avg_shard'][shard_start:shard_end],
bucket['exp_avg_sq_shard'][shard_start:shard_end],
bucket['grads_shard'][shard_start:shard_end],
params_shard_copy[shard_start:shard_end],
params_shard[shard_start:shard_end],
exp_avg[shard_start:shard_end],
exp_avg_sq[shard_start:shard_end],
grads[shard_start:shard_end],
params_copy[shard_start:shard_end],
])
# Fuse param fragments if possible
if len(buffers) == 1:
group_id = list(buffers.keys())[0]
buffers[group_id] = [(
bucket['params_shard'],
bucket['exp_avg_shard'],
bucket['exp_avg_sq_shard'],
bucket['grads_shard'],
params_shard_copy,
params_shard,
exp_avg,
exp_avg_sq,
grads,
params_copy,
)]
# Apply optimizer step to each param group
@@ -873,54 +1057,53 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self.state['step'],
1, # Set to 0 to apply eps inside sqrt
)
del group_buffers
# Deallocate buffers
del buffers
# Cast parameter dtype if needed
if params_copy.data_ptr() != params_bucket_shard.data_ptr():
params_bucket_shard.copy_(params_copy)
# Allgather updated parameters
if self.distributed_size == 1:
params_bucket = params_shard_copy
else:
params_bucket = torch.zeros(
[self.bucket_size],
dtype=self.param_sync_dtype,
device=self.device,
)
params_bucket_shards = [
if self.distributed_size > 1:
all_params_bucket_shards = [
params_bucket[i*self.shard_size:(i+1)*self.shard_size]
for i in range(self.distributed_size)
]
params_bucket_shards[self.distributed_rank].copy_(params_shard_copy)
if self._all_gather_no_copy:
no_copy_kwarg = { 'no_copy': True }
else:
no_copy_kwarg = {}
torch.distributed.all_gather(
params_bucket_shards,
params_bucket_shards[self.distributed_rank],
all_params_bucket_shards,
params_bucket_shard,
group=self.distributed_process_group,
**no_copy_kwarg,
)
del params_shard_copy
# Copy values to param buffers
buffers = collections.defaultdict(list) # param_in, param_out
for fragment in bucket['fragments']:
param_group_id = fragment['param_group_id']
param_id = fragment['param_id']
for fragment in fragments:
param_group_id = fragment.param_group_id
param_id = fragment.param_id
param = self.param_groups[param_group_id]['params'][param_id]
bucket_start, bucket_end = fragment['bucket_range']
param_start, param_end = fragment['param_range']
buffers[(param.is_cuda, param.dtype)].append((
params_bucket[bucket_start:bucket_end],
param.detach().view(-1)[param_start:param_end],
))
bucket_start, bucket_end = fragment.bucket_range
param_start, param_end = fragment.param_range
param_in = params_bucket[bucket_start:bucket_end]
param_out = param.detach().view(-1)[param_start:param_end]
if param_in.dtype == param_out.dtype:
# Just copy bytes if buffers have same type
param_in = param_in.view(torch.uint8)
param_out = param_out.view(torch.uint8)
buffers[(param.is_cuda, param.dtype)].append(
(param_in, param_out)
)
for (is_cuda, dtype), dtype_buffers in buffers.items():
fused_kernel_dtypes = (torch.float32, torch.float16, torch.uint8)
if (is_cuda
and dtype in fused_kernel_dtypes
and self.param_sync_dtype in fused_kernel_dtypes):
fused_kernel_dtypes = (
self.param_sync_dtype,
torch.float32,
torch.float16,
torch.uint8,
)
if is_cuda and dtype in fused_kernel_dtypes:
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
@@ -930,11 +1113,168 @@ class DistributedFusedAdam(torch.optim.Optimizer):
else:
for param_in, param_out in dtype_buffers:
param_out.copy_(param_in)
del dtype_buffers
del params_bucket, buffers
# Synchronize pipeline streams
for stream in self._pipeline_streams:
current_stream.wait_stream(stream)
main_stream.wait_stream(stream)
return loss
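Stripped of buckets, pipelining, fused kernels, and dtype casting, the step above follows the usual ZeRO pattern: apply Adam to the local parameter shard, then all-gather the shards so every rank ends up with the full updated parameters. A hedged single-shard sketch (hypothetical names; weight decay and gradient unscaling omitted):

import torch
import torch.distributed as dist

def sharded_adam_step(params_shard, exp_avg, exp_avg_sq, grads_shard,
                      full_params, group, lr=1e-3, betas=(0.9, 0.999),
                      eps=1e-8, step=1):
    beta1, beta2 = betas
    # Update moment estimates on the local shard only.
    exp_avg.mul_(beta1).add_(grads_shard, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grads_shard, grads_shard, value=1 - beta2)
    # Bias-corrected Adam update.
    denom = (exp_avg_sq / (1 - beta2 ** step)).sqrt_().add_(eps)
    params_shard.addcdiv_(exp_avg / (1 - beta1 ** step), denom, value=-lr)
    # Reconstruct the full parameter buffer on every rank.
    shards = list(full_params.chunk(dist.get_world_size(group)))
    dist.all_gather(shards, params_shard, group=group)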
def state_dict(self, gather_on_root=True):
"""Get dictionary containing optimizer state
Default behavior is to perform communication so that the
entire optimizer state is returned on the root rank in the
process group. In this case, all ranks in the process group
must enter this function and no value is returned on non-root
ranks.
Arguments:
gather_on_root (bool, optional): Gather state from all
ranks on the root rank (default: True)
"""
state_dict = super().state_dict()
if not gather_on_root:
return state_dict
# Export local state to byte string
state_bytes = io.BytesIO()
torch.save(state_dict, state_bytes)
state_bytes.seek(0)
state_bytes_view = state_bytes.getbuffer()
# Get data sizes on all ranks
local_state_size = len(state_bytes_view)
state_sizes = [None] * self.distributed_size
torch.distributed.all_gather_object(
state_sizes,
local_state_size,
group=self.process_group,
)
max_state_size = max(state_sizes)
# Construct workspace buffers
chunk_size = self.shard_size * torch.finfo(self.grad_sync_dtype).bits // 8
if self.distributed_rank == 0:
gathered_state_bytes = [state_bytes.getvalue()]
gathered_state_bytes.extend(bytearray(size) for size in state_sizes[1:])
gathered_chunks_buffers = [
torch.empty(
[chunk_size * self.distributed_size],
dtype=torch.uint8,
device=self.device,
)
for _ in range(self.pipeline_size)
]
else:
chunk_buffers = [
torch.empty(
[chunk_size],
dtype=torch.uint8,
device=self.device,
)
for _ in range(self.pipeline_size)
]
# Split data into chunks and gather on root rank
# Note: Assuming we are using the NCCL backend, communication
# must happen on the GPU. We split the data into fixed-size
# chunks so that the GPU memory usage is limited to
# (chunk_size * distributed_size) bytes.
# TODO: Avoid chunking with direct communication between CPUs
main_stream = torch.cuda.current_stream()
for stream in self._pipeline_streams:
stream.wait_stream(main_stream)
for stream_id, offset in enumerate(range(0, max_state_size, chunk_size)):
stream_id %= self.pipeline_size
# Buffers for chunk
if self.distributed_rank == 0:
gathered_chunks = [
gathered_chunks_buffers[stream_id][i*chunk_size:(i+1)*chunk_size]
for i in range(self.distributed_size)
]
else:
chunk = chunk_buffers[stream_id]
# Perform communication on parallel stream
stream = self._pipeline_streams[stream_id]
with torch.cuda.stream(stream):
# Copy to GPU
if self.distributed_rank != 0 and offset < local_state_size:
local_chunk_size = min(chunk_size, local_state_size-offset)
chunk[:local_chunk_size].copy_(
torch.frombuffer(
state_bytes_view,
dtype=torch.uint8,
count=local_chunk_size,
offset=offset,
),
non_blocking=True,
)
# Gather on root
if self.distributed_rank == 0:
if self._gather_no_copy:
no_copy_kwarg = { 'no_copy': True }
else:
no_copy_kwarg = {}
torch.distributed.gather(
gathered_chunks[0],
gathered_chunks,
dst=self._process_group_ranks[0],
group=self.process_group,
**no_copy_kwarg,
)
else:
torch.distributed.gather(
chunk,
dst=self._process_group_ranks[0],
group=self.process_group,
)
# Copy back to CPU
if self.distributed_rank == 0:
for rank in range(1, self.distributed_size):
if offset < state_sizes[rank]:
rank_chunk_size = min(chunk_size, state_sizes[rank]-offset)
torch.frombuffer(
gathered_state_bytes[rank],
dtype=torch.uint8,
count=rank_chunk_size,
offset=offset,
).copy_(
gathered_chunks[rank][:rank_chunk_size],
non_blocking=True,
)
# Synchronize GPU
for stream in self._pipeline_streams:
main_stream.wait_stream(stream)
main_stream.synchronize()
# Return gathered state data on root rank
if self.distributed_rank == 0:
return {'gathered_states': gathered_state_bytes}
else:
return None
def load_state_dict(self, state_dict):
"""Load optimizer state"""
# State dict contains state for all ranks
if 'gathered_states' in state_dict:
# Deallocate distributed optimizer state to reduce GPU
# memory usage
if 'buckets' in self.state:
del self.state['buckets']
# Get state for current rank and parse byte string
state_bytes = state_dict['gathered_states'][self.distributed_rank]
state_bytes = io.BytesIO(state_bytes)
state_dict = torch.load(state_bytes)
return super().load_state_dict(state_dict)
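A hedged checkpointing sketch following these semantics (it mirrors the test_checkpoint test below; the file name is hypothetical): every rank must call state_dict() so the gather can run, but only rank 0 receives and saves the gathered state.

# Save: all ranks participate in the gather; only rank 0 writes the file.
state = {'model': model.state_dict(), 'optim': optimizer.state_dict()}
if torch.distributed.get_rank() == 0:
    torch.save(state, 'checkpoint.pt')

# Load: rank 0 reads and broadcasts the checkpoint; load_state_dict then
# extracts each rank's own shard from the gathered optimizer state.
if torch.distributed.get_rank() == 0:
    objects = [torch.load('checkpoint.pt', map_location='cuda')]
else:
    objects = [None]
torch.distributed.broadcast_object_list(objects, src=0)
model.load_state_dict(objects[0]['model'])
optimizer.load_state_dict(objects[0]['optim'])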
from contextlib import contextmanager
import io
import os
import torch
@@ -25,6 +26,7 @@ def make_models(
num_layers,
size,
dtype=torch.float32,
param_sync_dtype=None,
device='cuda',
overlap_communication=True,
):
@@ -61,6 +63,8 @@ def make_models(
],
overlap_grad_sync=overlap_communication,
bucket_cap_mb=71/(4*1024*1024),
dtype=torch.float32,
param_sync_dtype=param_sync_dtype,
**optim_args,
)
@@ -87,6 +91,7 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
overlap_communication=True,
use_nosync=True,
dtype=torch.float32,
param_sync_dtype=None,
device='cuda',
rtol=None,
atol=None,
@@ -99,6 +104,7 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
num_layers,
layer_size,
dtype=dtype,
param_sync_dtype=param_sync_dtype,
device=device,
overlap_communication=overlap_communication,
)
@@ -172,6 +178,14 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
atol=1e-2,
)
def test_matches_pytorch_allgather_fp16(self):
self.test_matches_pytorch(
dtype=torch.float32,
param_sync_dtype=torch.float16,
rtol=1e-2,
atol=1e-2,
)
def test_raises_on_mismatch(self):
torch.manual_seed(self.seed + self.rank)
@@ -277,6 +291,101 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
dist_model.parameters()):
torch.testing.assert_close(dist_param, ref_param)
def test_checkpoint(self):
# Construct two models with same config and different params
num_layers = 5
layer_size = 2
torch.manual_seed(self.seed + self.rank)
_, _, model_save, optim_save = make_models(num_layers, layer_size)
_, _, model_load, optim_load = make_models(num_layers, layer_size)
# Train one of the models
num_steps = 3
micro_batch_steps = 2
batch_size = 4
for step in range(num_steps):
optim_save.zero_grad()
for micro_step in range(micro_batch_steps):
x = torch.rand(batch_size, layer_size) - 0.5
dy = torch.rand_like(x) - 0.5
x = x.cuda()
dy = dy.cuda()
y = model_save(x)
y.backward(dy)
optim_save.step()
# Make sure models are different
for param_save, param_load in zip(model_save.parameters(),
model_load.parameters()):
self.assertRaises(
AssertionError,
torch.testing.assert_close,
param_load, param_save,
)
# Save state on root rank and load on all ranks
state_dict = {
'model': model_save.state_dict(),
'optim': optim_save.state_dict(),
}
if self.rank == 0:
state_bytes = io.BytesIO()
torch.save(state_dict, state_bytes)
state_bytes = [state_bytes.getvalue()]
else:
state_bytes = [None]
torch.distributed.broadcast_object_list(state_bytes, src=0)
state_bytes = io.BytesIO(state_bytes[0])
state_dict = torch.load(state_bytes, map_location='cuda')
model_load.load_state_dict(state_dict['model'])
optim_load.load_state_dict(state_dict['optim'])
# Make sure models are identical
for param_save, param_load in zip(model_save.parameters(),
model_load.parameters()):
torch.testing.assert_close(param_load, param_save)
# Train both models
num_steps = 3
micro_batch_steps = 3
batch_size = 5
for step in range(num_steps):
# Reset gradients
optim_save.zero_grad()
optim_load.zero_grad()
# Forward and backward passes
for micro_step in range(micro_batch_steps):
# Synthetic data
x = torch.rand(batch_size, layer_size) - 0.5
dy = torch.rand_like(x) - 0.5
x = x.cuda()
dy = dy.cuda()
# Forward and backward pass
x_save = x.detach().clone().requires_grad_(True)
y_save = model_save(x_save)
y_save.backward(dy)
x_load = x.detach().clone().requires_grad_(True)
y_load = model_load(x_load)
y_load.backward(dy)
# Check that data tensors match
torch.testing.assert_close(y_load, y_save)
torch.testing.assert_close(x_load.grad, x_save.grad)
# Optimizer step
optim_save.step()
optim_load.step()
# Check that parameters match
for param_save, param_load in zip(model_save.parameters(),
model_load.parameters()):
torch.testing.assert_close(param_load, param_save)
if __name__ == "__main__":
# Assume script has been run with torchrun
common_utils.run_tests()