"src/vscode:/vscode.git/clone" did not exist on "13e48492f0aca759dda5056481d32b641af0450f"
Unverified commit 2e025ab5, authored by Tim Moon, committed by GitHub

Improvements in distributed Adam optimizer for Megatron (#1432)

* Improvements in distributed Adam optimizer for Megatron

- Add option to allocate gradient buckets out of one large contiguous buffer.
- Add option to initialize parameters in a user-provided order.
- Perform communication when saving optimizer state, so the full state can be gathered on the root rank.
- Support parameter synchronization with any dtype.

* Style fixes in distributed Adam helper classes

Review suggestions from @crcrpar
parent fb21698e
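For orientation before the diff, here is a rough usage sketch of the features this commit touches. It is illustrative only: it assumes an initialized NCCL process group, the usual apex.contrib import path, and made-up model and dtype choices.

# --- illustrative example, not part of the commit ---
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

model = torch.nn.Linear(1024, 1024).cuda()

# Options added or extended by this commit: contiguous_grad_buffer and
# param_sync_dtype. Values here are illustrative.
optim = DistributedFusedAdam(
    model.parameters(),
    lr=1e-3,
    dtype=torch.float32,
    param_sync_dtype=torch.float16,   # parameters broadcast in fp16
    contiguous_grad_buffer=True,      # grad buckets carved from one buffer
)

# Optionally initialize optimizer state in a user-provided parameter order.
optim.init_params(reversed(list(model.parameters())))

# Saving gathers the sharded state onto the root rank; all ranks must call it.
state = optim.state_dict()
if torch.distributed.get_rank() == 0:
    torch.save(state, "optim.pt")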
@@ -3,6 +3,7 @@ import contextlib
 import enum
 import importlib
 import inspect
+import io
 import math
 import threading
@@ -11,6 +12,10 @@ import amp_C
 from apex.multi_tensor_apply import multi_tensor_applier
 from torch.distributed.distributed_c10d import _get_default_group, _get_global_rank

+def _round_to_multiple(number, multiple, round_up=True):
+    """Assumes arguments are positive integers"""
+    return (number+multiple-1 if round_up else number) // multiple * multiple
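As a quick sanity check of the helper above (values are illustrative):

# --- illustrative example, not part of the commit ---
# round_up=True pads up to the next multiple; round_up=False truncates down.
assert _round_to_multiple(100, 32) == 128             # rounds up by default
assert _round_to_multiple(100, 32, round_up=False) == 96
assert _round_to_multiple(96, 32) == 96               # already aligned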
class DistributedFusedAdam(torch.optim.Optimizer):
    """AdamW optimizer with ZeRO algorithm.
@@ -49,8 +54,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        grad_sync_dtype (torch.dtype, optional): datatype for gradient
            synchronization (default: same as dtype)
        param_sync_dtype (torch.dtype, optional): datatype for
-            parameter synchronization (default: same as
-            grad_sync_dtype)
+            parameter synchronization (default: same as dtype)
        device (torch.device, optional): device for optimizer state
            (default: cuda). Currently only supports GPU with one GPU
            per process.
@@ -75,6 +79,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
            (default: 100)
        pipeline_size (int, optional): number of buckets to
            synchronize simultaneously (default: 2)
+        contiguous_grad_buffer (bool, optional): allocate gradient
+            buckets out of a large persistent buffer (default: False).
+            This allows individual parameter gradients to be accessed
+            externally (see grad_buffer_view function). It also
+            maximizes memory usage and may prevent overlapping
+            communication and compute.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
@@ -86,6 +96,56 @@ class DistributedFusedAdam(torch.optim.Optimizer):
    """
class ParameterFragment:
"""Buffer ranges for a parameter fragment
Describes corresponding regions in parameter buffer and
parameter bucket.
"""
def __init__(
self,
param_group_id,
param_id,
bucket_id,
param_range,
bucket_range,
in_local_shard,
shard_range,
shard_bucket_range,
shard_param_range,
):
# Parameter group index
self.param_group_id = param_group_id
# Parameter index within parameter group
self.param_id = param_id
# Bucket index
self.bucket_id = bucket_id
# Range within flattened parameter buffer
self.param_range = param_range
# Range within bucket
self.bucket_range = bucket_range
# Whether fragment is in local shard of bucket
self.in_local_shard = in_local_shard
# Range within local shard
self.shard_range = shard_range
# Range of local fragment shard within bucket
self.shard_bucket_range = shard_bucket_range
# Range of local fragment shard within parameter
self.shard_param_range = shard_param_range
class StateBucket:
def __init__(self, shard_size, dtype, device):
"""Optimizer state for a bucket"""
# Buffer ranges corresponding to parameter fragments
self.fragments = []
# Local shard of parameters
self.params_shard = torch.zeros([shard_size], dtype=dtype, device=device)
# Local shard of first moment estimate
self.exp_avg_shard = torch.zeros([shard_size], dtype=dtype, device=device)
# Local shard of second moment estimate
self.exp_avg_sq_shard = torch.zeros([shard_size], dtype=dtype, device=device)
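For intuition, each StateBucket above holds one rank's flat shard of a bucket; the fragments index into both the bucket and that shard. A minimal sketch of the shard layout (sizes made up):

# --- illustrative example, not part of the commit ---
# Each rank owns one contiguous shard of every bucket:
#   bucket elements [rank * shard_size, (rank + 1) * shard_size)
bucket_size, world_size = 16, 4
shard_size = bucket_size // world_size
for rank in range(world_size):
    shard_start = rank * shard_size
    shard_end = shard_start + shard_size
    print(f"rank {rank} owns bucket[{shard_start}:{shard_end}]")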
    class GradientStatus(enum.Enum):
        """Status of gradients within a bucket"""
        # Gradients are ready to use
@@ -97,6 +157,26 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        # Asynchronous reduction is in progress
        SYNCING = enum.auto()
class GradientBucket:
"""Gradient buffers and state for a bucket"""
def __init__(self):
# Local shard of gradients
self.grads_shard = None
# Local contribution to gradients
self.grads_bucket = None
# Buffer for gradient reduce-scatter
self.sync_grads_shard = None
# Status of gradients
self.status = DistributedFusedAdam.GradientStatus.READY
# Request object for asynchronous communication
self.sync_request = None
def sync_wait(self):
"""Wait for asynchronous communication to finish"""
if self.sync_request is not None:
self.sync_request.wait()
self.sync_request = None
    _step_supports_amp_scaling = True

    def __init__(self,
@@ -118,6 +198,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 overlap_grad_sync=True,
                 bucket_cap_mb=100,
                 pipeline_size=2,
+                 contiguous_grad_buffer=False,
                 ):
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay)
@@ -131,12 +212,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        if grad_sync_dtype is None:
            grad_sync_dtype = dtype
        if param_sync_dtype is None:
-            param_sync_dtype = grad_sync_dtype
+            param_sync_dtype = dtype
        supported_dtypes = [
-            (torch.float32, torch.float16, torch.float16),
-            (torch.float32, torch.float32, torch.float32),
+            (torch.float32, torch.float16),
+            (torch.float32, torch.float32),
        ]
-        if (dtype, grad_sync_dtype, param_sync_dtype) not in supported_dtypes:
+        if (dtype, grad_sync_dtype) not in supported_dtypes:
            raise RuntimeError(
                'Invalid dtypes for DistributedFusedAdam '
                f'(dtype={dtype}, '
@@ -176,6 +257,13 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                f'distributed process group size = {self.distributed_size}, '
                f'redundant process group size = {self.redundant_size})'
            )
+        try:
+            self._process_group_ranks = [
+                _get_global_rank(self.process_group, local_rank)
+                for local_rank in range(self.distributed_size)
+            ]
+        except:
+            self._process_group_ranks = list(range(self.distributed_size))
        # Use average reduction for grad sync
        self.average_grad_sync = average_grad_sync
@@ -185,13 +273,15 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        self.overlap_grad_sync = overlap_grad_sync
        # Number of buckets to synchronize at a time
        self.pipeline_size = pipeline_size
+        # Allocate contiguous buffer for gradients
+        self.contiguous_grad_buffer = contiguous_grad_buffer

        # Determine bucket sizes
        dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8
        self.alignment = 128 // dtype_size
        bucket_size = 1024*1024*bucket_cap_mb / dtype_size
-        shard_size = bucket_size / self.distributed_size
-        shard_size = (int(shard_size) // self.alignment) * self.alignment
+        shard_size = int(bucket_size / self.distributed_size)
+        shard_size = _round_to_multiple(shard_size, self.alignment, round_up=False)
        shard_size = max(shard_size, self.alignment)
        bucket_size = shard_size * self.distributed_size
        self.bucket_size = bucket_size
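A worked example of the sizing logic above, assuming fp16 gradient sync (2-byte elements), bucket_cap_mb=100, and 8 ranks (all values illustrative):

# --- illustrative example, not part of the commit ---
dtype_size = 2                                      # fp16 grad sync
alignment = 128 // dtype_size                       # 64 elements
bucket_size = 1024 * 1024 * 100 / dtype_size        # 52_428_800 elements
shard_size = int(bucket_size / 8)                   # 6_553_600 elements per rank
shard_size = (shard_size // alignment) * alignment  # already a multiple of 64
bucket_size = shard_size * 8                        # 52_428_800 elements total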
@@ -207,6 +297,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        self.state['step'] = 0

        # Objects for gradient synchronization
+        self._grads_buckets = collections.defaultdict(self.GradientBucket)
        self._grads_generated = set()
        self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)]
@@ -224,6 +315,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        self._all_gather_no_copy = (
            'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args
        )
+        self._gather_no_copy = (
+            'no_copy' in inspect.getfullargspec(torch.distributed.gather).args
+        )

        # Attach hooks for gradient synchronization
        self._register_post_backward_hooks()
@@ -236,20 +330,20 @@ class DistributedFusedAdam(torch.optim.Optimizer):
""" """
self._num_grads = 0 self._num_grads = 0
grad_buffer_size = 0
self._lock = threading.Lock() self._lock = threading.Lock()
self._grad_accs = [] self._grad_accs = []
try:
root_rank = _get_global_rank(self.process_group, 0)
except:
root_rank = 0
for param_group_id, group in enumerate(self.param_groups): for param_group_id, group in enumerate(self.param_groups):
for param_id, param in enumerate(group['params']): for param_id, param in enumerate(group['params']):
torch.distributed.broadcast( torch.distributed.broadcast(
param, param,
src=root_rank, src=self._process_group_ranks[0],
group=self.process_group, group=self.process_group,
) )
if param.requires_grad: if param.requires_grad:
self._num_grads += 1
# Callback after gradient is generated
def wrapper(p, p_group_id, p_id): def wrapper(p, p_group_id, p_id):
p_tmp = p.expand_as(p) p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0] grad_acc = p_tmp.grad_fn.next_functions[0][0]
...@@ -261,13 +355,57 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -261,13 +355,57 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._grad_copy(p) self._grad_copy(p)
if self.overlap_grad_sync: if self.overlap_grad_sync:
self._try_start_bucket_grad_sync( self._try_start_bucket_grad_sync(
[p], params=[p],
ignore_last_bucket=True, ignore_last_bucket=True,
) )
grad_acc.register_hook(reduction_hook) grad_acc.register_hook(reduction_hook)
self._grad_accs.append(grad_acc) self._grad_accs.append(grad_acc)
wrapper(param, param_group_id, param_id) wrapper(param, param_group_id, param_id)
self._num_grads += 1
# Gradient size, with padding for alignment
grad_size = _round_to_multiple(param.numel(), self.alignment)
grad_buffer_size += grad_size
# Allocate contiguous gradient buffer if needed
if self.contiguous_grad_buffer:
grad_buffer_size = _round_to_multiple(
grad_buffer_size,
self.bucket_size,
)
self._grad_buffer = torch.zeros(
[grad_buffer_size],
dtype=self.dtype,
device=self.device,
)
def init_params(self, params=None):
"""Initialize optimizer state for parameters
Arguments:
params (iterable, optional): parameters to initialize
(default: all parameters)
"""
# Default cases
if isinstance(params, torch.Tensor):
params = [params]
elif params is None:
params = []
for group in self.param_groups:
params.extend(group['params'])
# Get indices corresponding to parameters
id_map = dict()
for param_group_id, group in enumerate(self.param_groups):
for param_id, param in enumerate(group['params']):
id_map[param] = (param_group_id, param_id)
# Initialize parameters
for param in params:
if param in id_map and 'fragments' not in self.state[param]:
param_group_id, param_id = id_map[param]
self._init_param_state(param, param_group_id, param_id)
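The init_params method above is what lets a caller (for example a Megatron integration) control the order in which parameters are packed into buckets. A hedged usage sketch, with a made-up ordering policy:

# --- illustrative example, not part of the commit ---
# Hypothetical: pack the parameters of the last layers first so their
# buckets fill (and can start reducing) as soon as backprop reaches them.
ordered_params = list(model.parameters())[::-1]
optim.init_params(ordered_params)

# Calling init_params again is safe: parameters whose state already exists
# are skipped, so any remaining parameters can be initialized later.
optim.init_params()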
    def _init_param_state(
            self,
@@ -279,7 +417,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        # Make sure there is at least one bucket
        if not self.state['buckets']:
-            self._add_bucket()
+            self.state['buckets'].append(
+                self.StateBucket(self.shard_size, self.dtype, self.device)
+            )
        # Split parameter values into fragments
        # Note: Each fragment resides within a bucket
@@ -289,29 +429,25 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        while param_start < param_size:

            # Get current bucket
-            if not self.state['buckets']:
-                self._add_bucket()
            bucket_id = len(self.state['buckets']) - 1
            bucket = self.state['buckets'][bucket_id]
-            fragment_id = len(bucket['fragments'])
+            fragment_id = len(bucket.fragments)

            # Determine fragment position within bucket
            if fragment_id == 0:
                bucket_start = 0
            else:
-                bucket_start = bucket['fragments'][-1]['bucket_range'][1]
-                bucket_start = (
-                    (bucket_start + self.alignment - 1)
-                    // self.alignment
-                    * self.alignment
-                ) # Pad until fragment is aligned
+                _, bucket_start = bucket.fragments[-1].bucket_range
+                bucket_start = _round_to_multiple(bucket_start, self.alignment)
            fragment_size = min(param_size-param_start, self.bucket_size-bucket_start)
            param_end = param_start + fragment_size
            bucket_end = bucket_start + fragment_size

            # Create new bucket if current one is full
            if fragment_size <= 0:
-                self._add_bucket()
+                self.state['buckets'].append(
+                    self.StateBucket(self.shard_size, self.dtype, self.device)
+                )
                continue
            # Fragment position within local shard
@@ -331,81 +467,55 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                shard_param_start, shard_param_end = None, None

            # Record fragment info
-            fragment = {
-                # Parameter group index
-                'param_group_id': param_group_id,
-                # Parameter index within parameter group
-                'param_id': param_id,
-                # Bucket index
-                'bucket_id': bucket_id,
-                # Range within flattened parameter buffer
-                'param_range': (param_start,param_end),
-                # Range within bucket
-                'bucket_range': (bucket_start,bucket_end),
-                # Whether fragment is in local shard of bucket
-                'in_local_shard': in_local_shard,
-                # Range within local shard
-                'shard_range': (shard_start,shard_end),
-                # Range of local fragment shard within bucket
-                'shard_bucket_range': (shard_bucket_start,shard_bucket_end),
-                # Range of local fragment shard within parameter
-                'shard_param_range': (shard_param_start,shard_param_end),
-            }
-
-            # Record fragment info
+            fragment = self.ParameterFragment(
+                param_group_id=param_group_id,
+                param_id=param_id,
+                bucket_id=bucket_id,
+                param_range=(param_start,param_end),
+                bucket_range=(bucket_start,bucket_end),
+                in_local_shard=in_local_shard,
+                shard_range=(shard_start,shard_end),
+                shard_bucket_range=(shard_bucket_start,shard_bucket_end),
+                shard_param_range=(shard_param_start,shard_param_end),
+            )
            self.state[param]['fragments'].append(fragment)
-            bucket['fragments'].append(fragment)
+            bucket.fragments.append(fragment)
            param_start = param_end

        # Initialize master param buffer
        for fragment in self.state[param]['fragments']:
-            if fragment['in_local_shard']:
-                bucket_id = fragment['bucket_id']
-                bucket = self.state['buckets'][bucket_id]
-                param_start, param_end = fragment['shard_param_range']
-                shard_start, shard_end = fragment['shard_range']
+            if fragment.in_local_shard:
+                bucket = self.state['buckets'][fragment.bucket_id]
+                param_start, param_end = fragment.shard_param_range
+                shard_start, shard_end = fragment.shard_range
                model_param_fragment = param.view(-1)[param_start:param_end]
-                master_param_fragment = bucket['params_shard'][shard_start:shard_end]
+                master_param_fragment = bucket.params_shard[shard_start:shard_end]
                master_param_fragment.copy_(model_param_fragment)
-    def _add_bucket(self):
-        """Construct a bucket for optimizer state"""
-        self.state['buckets'].append({
-            # Parameter fragments associated with bucket
-            'fragments': [],
-            # Gradient buffers
-            'grads_shard': None,
-            'grads_bucket': None,
-            'curr_grads_shard': None, # For current micro-batch
-            # Optimizer state
-            'params_shard': torch.zeros([self.shard_size], dtype=self.dtype, device=self.device),
-            'exp_avg_shard': torch.zeros([self.shard_size], dtype=self.dtype, device=self.device),
-            'exp_avg_sq_shard': torch.zeros([self.shard_size], dtype=self.dtype, device=self.device),
-            # Status of parameter gradients
-            'gradient_status': self.GradientStatus.READY,
-            # Distributed request object for gradient synchronization
-            'grad_sync_request': None,
-        })
-
    def zero_grad(self, set_to_none=True):
        """Clear parameter gradients"""

+        # Reset bucket buffers
+        self._grads_buckets.clear()
+
+        # Construct views into contiguous grad buffer, if needed
+        if self.contiguous_grad_buffer:
+            self._grad_buffer.zero_()
+            for bucket_id in range(len(self.state['buckets'])):
+                bucket_start = bucket_id * self.bucket_size
+                bucket_end = bucket_start + self.bucket_size
+                bucket = self._grads_buckets[bucket_id]
+                bucket.grads_bucket = self._grad_buffer[bucket_start:bucket_end]
+
+        # Reset param grads
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None or set_to_none:
                    param.grad = None
                else:
                    param.grad.zero_()
-        for bucket in self.state['buckets']:
-            bucket['grads_shard'] = None
-            bucket['grads_bucket'] = None
-            bucket['curr_grads_shard'] = None
-            bucket['gradient_status'] = self.GradientStatus.READY
+
+        # Reset other state
        self._grads_generated = set()
        self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device)
        self._grad_norm = None
@@ -417,51 +527,80 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        for fragment in self.state[param]['fragments']:

            # Get fragment position
-            bucket_id = fragment['bucket_id']
-            bucket = self.state['buckets'][bucket_id]
-            grad_start, grad_end = fragment['param_range']
-            bucket_start, bucket_end = fragment['bucket_range']
+            bucket_id = fragment.bucket_id
+            bucket = self._grads_buckets[bucket_id]
+            grad_start, grad_end = fragment.param_range
+            bucket_start, bucket_end = fragment.bucket_range

            # Set reduction status
-            if bucket['gradient_status'] == self.GradientStatus.SYNCING:
+            if bucket.status == self.GradientStatus.SYNCING:
                self._finish_bucket_grad_sync()
-            bucket['gradient_status'] = self.GradientStatus.PARTIALLY_FILLED
+            bucket.status = self.GradientStatus.PARTIALLY_FILLED

            # Allocate gradient buffer if needed
-            if bucket['grads_bucket'] is None:
-                bucket['grads_bucket'] = torch.zeros(
-                    [self.bucket_size],
-                    dtype=self.grad_sync_dtype,
-                    device=self.device,
-                )
+            if bucket.grads_bucket is None:
+                if self.contiguous_grad_buffer:
+                    grad_buffer_start = bucket_id * self.bucket_size
+                    grad_buffer_end = grad_buffer_start + self.bucket_size
+                    bucket.grads_bucket = self._grad_buffer[grad_buffer_start:grad_buffer_end]
+                else:
+                    bucket.grads_bucket = torch.empty(
+                        [self.bucket_size],
+                        dtype=self.grad_sync_dtype,
+                        device=self.device,
+                    )
+                bucket.grads_bucket.zero_()

            # Copy param grad to bucket
            if param.grad is not None:
-                scale = 1/self.process_group_size if self.average_grad_sync else 1.0
                grad_in = param.grad.detach().view(-1)[grad_start:grad_end]
-                grad_out = bucket['grads_bucket'][bucket_start:bucket_end]
-                grad_out.add_(grad_in, alpha=scale)
+                grad_out = bucket.grads_bucket[bucket_start:bucket_end]
+                if grad_in.data_ptr() != grad_out.data_ptr():
+                    grad_out.add_(grad_in)

            # Free param grad buffer
            param.grad = None
def grad_buffer_view(self, param):
"""Construct view into grad buffer corresponding to param
Assumes optimizer is using a contiguous grad buffer.
"""
assert self.contiguous_grad_buffer
# Figure out corresponding position in grad buffer
param_fragments = self.state[param]['fragments']
start_bucket_id = param_fragments[0].bucket_id
start_bucket_offset, _ = param_fragments[0].bucket_range
end_bucket_id = param_fragments[-1].bucket_id
_, end_bucket_offset = param_fragments[-1].bucket_range
buffer_start = start_bucket_id * self.bucket_size + start_bucket_offset
buffer_end = end_bucket_id * self.bucket_size + end_bucket_offset
# Construct view into grad buffer
flat_buffer = self._grad_buffer[buffer_start:buffer_end]
return flat_buffer.detach().view(param.size())
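A short sketch of how the contiguous buffer and grad_buffer_view above might be used externally. This is a hypothetical fragment: it assumes an initialized NCCL process group, a toy model, and the import shown in the earlier example.

# --- illustrative example, not part of the commit ---
import torch

model = torch.nn.Linear(32, 32).cuda()
optim = DistributedFusedAdam(model.parameters(), lr=1e-3,
                             contiguous_grad_buffer=True)

model(torch.randn(8, 32, device='cuda')).sum().backward()

# Each view aliases the persistent gradient buffer, so external code (e.g. a
# custom clipping or logging routine) can access a parameter's gradient
# region without an extra copy.
for p in model.parameters():
    g = optim.grad_buffer_view(p)
    assert g.shape == p.shape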
    def _force_bucket_grad_sync(self):
        """Ensure that all gradient buckets are synchronized"""

        # Synchronize all unsynchronized buckets
        self._finish_bucket_grad_sync()
        buckets = [
-            bucket for bucket in self.state['buckets']
-            if bucket['gradient_status'] != self.GradientStatus.READY
+            bucket
+            for bucket_id, bucket in sorted(self._grads_buckets.items())
+            if bucket.status != self.GradientStatus.READY
        ]
        if buckets:
            self._start_bucket_grad_sync(buckets)
            self._finish_bucket_grad_sync()

-        # Fill any unfilled buckets with zeros
-        for bucket in self.state['buckets']:
-            if bucket['grads_shard'] is None:
-                bucket['grads_shard'] = torch.zeros(
+        # Fill any unsynchronized gradients with zeros
+        for bucket_id in range(len(self.state['buckets'])):
+            bucket = self._grads_buckets[bucket_id]
+            if bucket.grads_shard is None:
+                bucket.grads_shard = torch.zeros(
                    [self.shard_size],
                    dtype=self.grad_sync_dtype,
                    device=self.device,
@@ -495,75 +634,91 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        for param in params:
            self._grads_generated.add(param)
            for fragment in self.state[param]['fragments']:
-                bucket_id = fragment['bucket_id']
-                bucket = self.state['buckets'][bucket_id]
+                bucket_id = fragment.bucket_id
+                bucket_fragments = self.state['buckets'][bucket_id].fragments
                is_filled = True
-                for other_fragment in reversed(bucket['fragments']):
-                    param_group_id = other_fragment['param_group_id']
-                    param_id = other_fragment['param_id']
+                for other_fragment in reversed(bucket_fragments):
+                    param_group_id = other_fragment.param_group_id
+                    param_id = other_fragment.param_id
                    other_param = self.param_groups[param_group_id]['params'][param_id]
                    if other_param not in self._grads_generated:
                        is_filled = False
                        break
                if is_filled:
-                    bucket['gradient_status'] = self.GradientStatus.FULLY_FILLED
+                    bucket = self._grads_buckets[bucket_id]
+                    bucket.status = self.GradientStatus.FULLY_FILLED

        # Launch reductions if enough buckets are ready
        if len(self._grads_generated) == self._num_grads:
            self._force_bucket_grad_sync()
        else:
-            all_buckets = self.state['buckets']
-            if ignore_last_bucket:
-                all_buckets = all_buckets[:-1]
-            filled_buckets = [
-                bucket
-                for bucket in all_buckets
-                if bucket['gradient_status'] == self.GradientStatus.FULLY_FILLED
-            ]
-            pipeline_size = (len(filled_buckets) // self.pipeline_size) * self.pipeline_size
+            filled_buckets = []
+            for bucket_id, bucket in sorted(self._grads_buckets.items()):
+                if ignore_last_bucket and bucket_id == len(self.state['buckets'])-1:
+                    continue
+                if bucket.status == self.GradientStatus.FULLY_FILLED:
+                    filled_buckets.append(bucket)
+            pipeline_size = _round_to_multiple(
+                len(filled_buckets),
+                self.pipeline_size,
+            )
            if pipeline_size > 0:
                self._start_bucket_grad_sync(filled_buckets[:pipeline_size])
    def _start_bucket_grad_sync(self, buckets):
-        """Synchronize gradients in buckets
+        """Synchronize gradient buckets

        Gradient synchronization is asynchronous. Involves
        reduce-scatter over distributed process group and allreduce
        over redundant process group.

        """
+        # Call recursively if more buckets than streams
+        while len(buckets) > self.pipeline_size:
+            self._start_bucket_grad_sync(buckets[:self.pipeline_size])
+            buckets = buckets[self.pipeline_size:]
        self._finish_bucket_grad_sync()

+        # Reduction operation
+        if self.average_grad_sync:
+            reduce_op = torch.distributed.ReduceOp.AVG
+        else:
+            reduce_op = torch.distributed.ReduceOp.SUM
+
        # Reduce gradients
+        main_stream = torch.cuda.current_stream()
        for stream in self._pipeline_streams:
-            stream.wait_stream(torch.cuda.current_stream())
+            stream.wait_stream(main_stream)
        for i, bucket in enumerate(buckets):
-            bucket['gradient_status'] = self.GradientStatus.SYNCING
+            bucket.status = self.GradientStatus.SYNCING
            stream = self._pipeline_streams[i % self.pipeline_size]
            with torch.cuda.stream(stream):

                # Reduce-scatter over distributed process group
+                bucket.sync_wait()
                if self.distributed_size == 1:
-                    bucket['curr_grads_shard'] = bucket['grads_bucket']
-                    bucket['grad_sync_request'] = None
+                    bucket.sync_grads_shard = bucket.grads_bucket
                else:
-                    bucket['curr_grads_shard'] = torch.zeros(
-                        [self.shard_size],
-                        dtype=self.grad_sync_dtype,
-                        device=self.device,
-                    )
+                    with torch.cuda.stream(main_stream):
+                        bucket.sync_grads_shard = torch.zeros(
+                            [self.shard_size],
+                            dtype=self.grad_sync_dtype,
+                            device=self.device,
+                        )
                    grads_bucket_shards = [
-                        bucket['grads_bucket'][i*self.shard_size:(i+1)*self.shard_size]
+                        bucket.grads_bucket[i*self.shard_size:(i+1)*self.shard_size]
                        for i in range(self.distributed_size)
                    ]
                    if self._reduce_scatter_no_copy:
                        no_copy_kwarg = { 'no_copy': True }
                    else:
                        no_copy_kwarg = {}
-                    bucket['grad_sync_request'] = (
+                    bucket.sync_request = (
                        torch.distributed.reduce_scatter(
-                            bucket['curr_grads_shard'],
+                            bucket.sync_grads_shard,
                            grads_bucket_shards,
+                            op=reduce_op,
                            group=self.distributed_process_group,
                            async_op=True,
                            **no_copy_kwarg,
@@ -576,11 +731,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                # submitted in a consistent order. There could be race
                # conditions if wait doesn't finish in order.
                if self.redundant_size > 1:
-                    if bucket['grad_sync_request'] is not None:
-                        bucket['grad_sync_request'].wait()
-                    bucket['grad_sync_request'] = (
+                    bucket.sync_wait()
+                    bucket.sync_request = (
                        torch.distributed.all_reduce(
-                            bucket['curr_grads_shard'],
+                            bucket.sync_grads_shard,
+                            op=reduce_op,
                            group=self.redundant_process_group,
                            async_op=True,
                        )
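For readers unfamiliar with the pattern, the method above is essentially the ZeRO gradient path: each rank reduce-scatters a full bucket into its own shard, and later (in step) the updated parameter shards are all-gathered back. A stripped-down illustration with plain torch.distributed, without bucketing, pipelining, or the redundant group (names and the update callable are made up):

# --- illustrative example, not part of the commit ---
import torch
import torch.distributed as dist

def zero_style_sync(grads_bucket, update_shard, world_size):
    """Toy version: reduce-scatter grads, update local shard, all-gather params."""
    shard_size = grads_bucket.numel() // world_size  # assumes exact divisibility
    grad_shard = torch.empty(shard_size, dtype=grads_bucket.dtype,
                             device=grads_bucket.device)
    in_shards = list(grads_bucket.split(shard_size))
    # ReduceOp.AVG requires a recent PyTorch/NCCL, as in the commit itself.
    dist.reduce_scatter(grad_shard, in_shards, op=dist.ReduceOp.AVG)

    # Local optimizer math would update this rank's parameter shard here.
    param_shard = update_shard(grad_shard)

    params_bucket = torch.empty_like(grads_bucket)
    out_shards = list(params_bucket.split(shard_size))
    dist.all_gather(out_shards, param_shard)
    return params_bucket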
@@ -588,26 +743,22 @@ class DistributedFusedAdam(torch.optim.Optimizer):
    def _finish_bucket_grad_sync(self):
        """Wait for any gradient synchronizations that are in progress"""
-        for bucket in self.state['buckets']:
-            if bucket['gradient_status'] == self.GradientStatus.SYNCING:
+        for bucket_id, bucket in sorted(self._grads_buckets.items()):
+            if bucket.status == self.GradientStatus.SYNCING:

                # Finish asynchronous communication
-                if bucket['grad_sync_request'] is not None:
-                    bucket['grad_sync_request'].wait()
-                    bucket['grad_sync_request'] = None
+                bucket.sync_wait()

                # Accumulate gradient in local shard
-                if bucket['grads_shard'] is None:
-                    bucket['grads_shard'] = bucket['curr_grads_shard']
+                if bucket.grads_shard is None:
+                    bucket.grads_shard = bucket.sync_grads_shard
                else:
-                    bucket['grads_shard'].add_(bucket['curr_grads_shard'])
-
-                # Deallocate buffers for gradient synchronization
-                bucket['grads_bucket'] = None
-                bucket['curr_grads_shard'] = None
+                    bucket.grads_shard.add_(bucket.sync_grads_shard)
+                bucket.grads_bucket = None
+                bucket.sync_grads_shard = None

                # Reset status
-                bucket['gradient_status'] = self.GradientStatus.READY
+                bucket.status = self.GradientStatus.READY

        # Cached gradient norm has been invalidated
        self._grad_norm = None
@@ -642,14 +793,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
    def grad_sync(self):
        """Ensure that all gradients are synchronized"""
        for bucket in self.state['buckets']:
-            for fragment in bucket['fragments']:
-                param_group_id = fragment['param_group_id']
-                param_id = fragment['param_id']
+            for fragment in bucket.fragments:
+                param_group_id = fragment.param_group_id
+                param_id = fragment.param_id
                param = self.param_groups[param_group_id]['params'][param_id]
                if param.grad is not None:
                    self._grad_copy(param)
                    self._try_start_bucket_grad_sync(
-                        [param],
+                        params=[param],
                        ignore_last_bucket=False,
                    )
        self._force_bucket_grad_sync()
@@ -676,7 +827,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
            grad_norm_sq = multi_tensor_applier(
                amp_C.multi_tensor_l2norm,
                dummy_overflow_buf,
-                [[bucket['grads_shard'] for bucket in self.state['buckets']]],
+                [[bucket.grads_shard for bucket in self._grads_buckets.values()]],
                False,
            )[0] ** 2
        else:
@@ -684,11 +835,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
            grads = []
            for param in parameters:
                for fragment in self.state[param]['fragments']:
-                    if fragment['in_local_shard']:
-                        bucket_id = fragment['bucket_id']
-                        bucket = self.state['buckets'][bucket_id]
-                        shard_start, shard_end = fragment['shard_range']
-                        grads.append(bucket['grads_shard'][shard_start:shard_end])
+                    if fragment.in_local_shard:
+                        bucket = self._grads_buckets[fragment.bucket_id]
+                        shard_start, shard_end = fragment.shard_range
+                        grads.append(bucket.grads_shard[shard_start:shard_end])
            if grads:
                dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
                grad_norm_sq = multi_tensor_applier(
@@ -798,45 +948,79 @@ class DistributedFusedAdam(torch.optim.Optimizer):
            self._inv_grad_scale *= grad_scaler._scale
        inv_grad_scale = self._inv_grad_scale.item()
# Construct workspace buffers
params_bucket_buffers = [
torch.empty(
[self.bucket_size],
dtype=self.param_sync_dtype,
device=self.device,
)
for _ in range(self.pipeline_size)
]
if self.grad_sync_dtype == self.param_sync_dtype:
shard_start = self.distributed_rank * self.shard_size
shard_end = shard_start + self.shard_size
params_copy_buffers = [
params_bucket[shard_start:shard_end]
for params_bucket in params_bucket_buffers
]
else:
params_copy_buffers = [
torch.empty(
[self.shard_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
for _ in range(self.pipeline_size)
]
        # Apply optimizer step to each bucket and synchronize params
        self.state['step'] += 1
-        current_stream = torch.cuda.current_stream()
+        main_stream = torch.cuda.current_stream()
        for stream in self._pipeline_streams:
-            stream.wait_stream(current_stream)
-        for i, bucket in enumerate(self.state['buckets']):
-            stream = self._pipeline_streams[i % self.pipeline_size]
+            stream.wait_stream(main_stream)
+        for bucket_id in range(len(self.state['buckets'])):
+            stream_id = bucket_id % self.pipeline_size
+
+            # Bucket buffers
+            fragments = self.state['buckets'][bucket_id].fragments
+            shard_start = self.distributed_rank * self.shard_size
+            shard_end = shard_start + self.shard_size
+            params_bucket = params_bucket_buffers[stream_id]
+            params_bucket_shard = params_bucket[shard_start:shard_end]
+            params_shard = self.state['buckets'][bucket_id].params_shard
+            params_copy = params_copy_buffers[stream_id]
+            exp_avg = self.state['buckets'][bucket_id].exp_avg_shard
+            exp_avg_sq = self.state['buckets'][bucket_id].exp_avg_sq_shard
+            grads = self._grads_buckets[bucket_id].grads_shard
+
+            # Perform compute on parallel stream
+            stream = self._pipeline_streams[stream_id]
            with torch.cuda.stream(stream):
-                # Buffer for param sync
-                params_shard_copy = torch.zeros(
-                    [self.shard_size],
-                    dtype=self.param_sync_dtype,
-                    device=self.device,
-                )
-
                # Find param fragments in local shard
                buffers = collections.defaultdict(list) # p, m, v, g, p_copy
-                for fragment in bucket['fragments']:
-                    if fragment['in_local_shard']:
-                        param_group_id = fragment['param_group_id']
-                        shard_start, shard_end = fragment['shard_range']
+                for fragment in fragments:
+                    if fragment.in_local_shard:
+                        param_group_id = fragment.param_group_id
+                        shard_start, shard_end = fragment.shard_range
                        buffers[param_group_id].append([
-                            bucket['params_shard'][shard_start:shard_end],
-                            bucket['exp_avg_shard'][shard_start:shard_end],
-                            bucket['exp_avg_sq_shard'][shard_start:shard_end],
-                            bucket['grads_shard'][shard_start:shard_end],
-                            params_shard_copy[shard_start:shard_end],
+                            params_shard[shard_start:shard_end],
+                            exp_avg[shard_start:shard_end],
+                            exp_avg_sq[shard_start:shard_end],
+                            grads[shard_start:shard_end],
+                            params_copy[shard_start:shard_end],
                        ])

                # Fuse param fragments if possible
                if len(buffers) == 1:
                    group_id = list(buffers.keys())[0]
                    buffers[group_id] = [(
-                        bucket['params_shard'],
-                        bucket['exp_avg_shard'],
-                        bucket['exp_avg_sq_shard'],
-                        bucket['grads_shard'],
-                        params_shard_copy,
+                        params_shard,
+                        exp_avg,
+                        exp_avg_sq,
+                        grads,
+                        params_copy,
                    )]
                # Apply optimizer step to each param group
@@ -873,54 +1057,53 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                        self.state['step'],
                        1, # Set to 0 to apply eps inside sqrt
                    )
+                del group_buffers

-                # Deallocate buffers
-                del buffers
+                # Cast parameter dtype if needed
+                if params_copy.data_ptr() != params_bucket_shard.data_ptr():
+                    params_bucket_shard.copy_(params_copy)

                # Allgather updated parameters
-                if self.distributed_size == 1:
-                    params_bucket = params_shard_copy
-                else:
-                    params_bucket = torch.zeros(
-                        [self.bucket_size],
-                        dtype=self.param_sync_dtype,
-                        device=self.device,
-                    )
-                    params_bucket_shards = [
+                if self.distributed_size > 1:
+                    all_params_bucket_shards = [
                        params_bucket[i*self.shard_size:(i+1)*self.shard_size]
                        for i in range(self.distributed_size)
                    ]
-                    params_bucket_shards[self.distributed_rank].copy_(params_shard_copy)
                    if self._all_gather_no_copy:
                        no_copy_kwarg = { 'no_copy': True }
                    else:
                        no_copy_kwarg = {}
                    torch.distributed.all_gather(
-                        params_bucket_shards,
-                        params_bucket_shards[self.distributed_rank],
+                        all_params_bucket_shards,
+                        params_bucket_shard,
                        group=self.distributed_process_group,
                        **no_copy_kwarg,
                    )
-                    del params_shard_copy

                # Copy values to param buffers
                buffers = collections.defaultdict(list) # param_in, param_out
-                for fragment in bucket['fragments']:
-                    param_group_id = fragment['param_group_id']
-                    param_id = fragment['param_id']
+                for fragment in fragments:
+                    param_group_id = fragment.param_group_id
+                    param_id = fragment.param_id
                    param = self.param_groups[param_group_id]['params'][param_id]
-                    bucket_start, bucket_end = fragment['bucket_range']
-                    param_start, param_end = fragment['param_range']
-                    buffers[(param.is_cuda, param.dtype)].append((
-                        params_bucket[bucket_start:bucket_end],
-                        param.detach().view(-1)[param_start:param_end],
-                    ))
+                    bucket_start, bucket_end = fragment.bucket_range
+                    param_start, param_end = fragment.param_range
+                    param_in = params_bucket[bucket_start:bucket_end]
+                    param_out = param.detach().view(-1)[param_start:param_end]
+                    if param_in.dtype == param_out.dtype:
+                        # Just copy bytes if buffers have same type
+                        param_in = param_in.view(torch.uint8)
+                        param_out = param_out.view(torch.uint8)
+                    buffers[(param.is_cuda, param.dtype)].append(
+                        (param_in, param_out)
+                    )
                for (is_cuda, dtype), dtype_buffers in buffers.items():
-                    fused_kernel_dtypes = (torch.float32, torch.float16, torch.uint8)
-                    if (is_cuda
-                            and dtype in fused_kernel_dtypes
-                            and self.param_sync_dtype in fused_kernel_dtypes):
+                    fused_kernel_dtypes = (
+                        self.param_sync_dtype,
+                        torch.float32,
+                        torch.float16,
+                        torch.uint8,
+                    )
+                    if is_cuda and dtype in fused_kernel_dtypes:
                        dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
                        multi_tensor_applier(
                            fused_adam_cuda.maybe_cast_mt,
@@ -930,11 +1113,168 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                    else:
                        for param_in, param_out in dtype_buffers:
                            param_out.copy_(param_in)
+                del dtype_buffers
+                del params_bucket, buffers

        # Synchronize pipeline streams
        for stream in self._pipeline_streams:
-            current_stream.wait_stream(stream)
+            main_stream.wait_stream(stream)

        return loss
def state_dict(self, gather_on_root=True):
"""Get dictionary containing optimizer state
Default behavior is to perform communication so that the
entire optimizer state is returned on the root rank in the
process group. In this case, all ranks in the process group
must enter this function and no value is returned on non-root
ranks.
Arguments:
gather_on_root (bool, optional): Gather state from all
ranks on the root rank (default: True)
"""
state_dict = super().state_dict()
if not gather_on_root:
return state_dict
# Export local state to byte string
state_bytes = io.BytesIO()
torch.save(state_dict, state_bytes)
state_bytes.seek(0)
state_bytes_view = state_bytes.getbuffer()
# Get data sizes on all ranks
local_state_size = len(state_bytes_view)
state_sizes = [None] * self.distributed_size
torch.distributed.all_gather_object(
state_sizes,
local_state_size,
group=self.process_group,
)
max_state_size = max(state_sizes)
# Construct workspace buffers
chunk_size = self.shard_size * torch.finfo(self.grad_sync_dtype).bits // 8
if self.distributed_rank == 0:
gathered_state_bytes = [state_bytes.getvalue()]
gathered_state_bytes.extend(bytearray(size) for size in state_sizes[1:])
gathered_chunks_buffers = [
torch.empty(
[chunk_size * self.distributed_size],
dtype=torch.uint8,
device=self.device,
)
for _ in range(self.pipeline_size)
]
else:
chunk_buffers = [
torch.empty(
[chunk_size],
dtype=torch.uint8,
device=self.device,
)
for _ in range(self.pipeline_size)
]
# Split data into chunks and gather on root rank
# Note: Assuming we are using the NCCL backend, communication
# must happen on the GPU. We split the data into fixed-size
# chunks so that the GPU memory usage is limited to
# (chunk_size * distributed_size) bytes.
# TODO: Avoid chunking with direct communication between CPUs
main_stream = torch.cuda.current_stream()
for stream in self._pipeline_streams:
stream.wait_stream(main_stream)
for stream_id, offset in enumerate(range(0, max_state_size, chunk_size)):
stream_id %= self.pipeline_size
# Buffers for chunk
if self.distributed_rank == 0:
gathered_chunks = [
gathered_chunks_buffers[stream_id][i*chunk_size:(i+1)*chunk_size]
for i in range(self.distributed_size)
]
else:
chunk = chunk_buffers[stream_id]
# Perform communication on parallel stream
stream = self._pipeline_streams[stream_id]
with torch.cuda.stream(stream):
# Copy to GPU
if self.distributed_rank != 0 and offset < local_state_size:
local_chunk_size = min(chunk_size, local_state_size-offset)
chunk[:local_chunk_size].copy_(
torch.frombuffer(
state_bytes_view,
dtype=torch.uint8,
count=local_chunk_size,
offset=offset,
),
non_blocking=True,
)
# Gather on root
if self.distributed_rank == 0:
if self._gather_no_copy:
no_copy_kwarg = { 'no_copy': True }
else:
no_copy_kwarg = {}
torch.distributed.gather(
gathered_chunks[0],
gathered_chunks,
dst=self._process_group_ranks[0],
group=self.process_group,
**no_copy_kwarg,
)
else:
torch.distributed.gather(
chunk,
dst=self._process_group_ranks[0],
group=self.process_group,
)
# Copy back to CPU
if self.distributed_rank == 0:
for rank in range(1, self.distributed_size):
if offset < state_sizes[rank]:
rank_chunk_size = min(chunk_size, state_sizes[rank]-offset)
torch.frombuffer(
gathered_state_bytes[rank],
dtype=torch.uint8,
count=rank_chunk_size,
offset=offset,
).copy_(
gathered_chunks[rank][:rank_chunk_size],
non_blocking=True,
)
# Synchronize GPU
for stream in self._pipeline_streams:
main_stream.wait_stream(stream)
main_stream.synchronize()
# Return gathered state data on root rank
if self.distributed_rank == 0:
return {'gathered_states': gathered_state_bytes}
else:
return None
def load_state_dict(self, state_dict):
"""Load optimizer state"""
# State dict contains state for all ranks
if 'gathered_states' in state_dict:
# Deallocate distributed optimizer state to reduce GPU
# memory usage
if 'buckets' in self.state:
del self.state['buckets']
# Get state for current rank and parse byte string
state_bytes = state_dict['gathered_states'][self.distributed_rank]
state_bytes = io.BytesIO(state_bytes)
state_dict = torch.load(state_bytes)
return super().load_state_dict(state_dict)
 from contextlib import contextmanager
+import io
 import os

 import torch
@@ -25,6 +26,7 @@ def make_models(
        num_layers,
        size,
        dtype=torch.float32,
+        param_sync_dtype=None,
        device='cuda',
        overlap_communication=True,
):
@@ -61,6 +63,8 @@ def make_models(
        ],
        overlap_grad_sync=overlap_communication,
        bucket_cap_mb=71/(4*1024*1024),
+        dtype=torch.float32,
+        param_sync_dtype=param_sync_dtype,
        **optim_args,
    )
@@ -87,6 +91,7 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
            overlap_communication=True,
            use_nosync=True,
            dtype=torch.float32,
+            param_sync_dtype=None,
            device='cuda',
            rtol=None,
            atol=None,
@@ -99,6 +104,7 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
            num_layers,
            layer_size,
            dtype=dtype,
+            param_sync_dtype=param_sync_dtype,
            device=device,
            overlap_communication=overlap_communication,
        )
@@ -172,6 +178,14 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
            atol=1e-2,
        )
def test_matches_pytorch_allgather_fp16(self):
self.test_matches_pytorch(
dtype=torch.float32,
param_sync_dtype=torch.float16,
rtol=1e-2,
atol=1e-2,
)
    def test_raises_on_mismatch(self):
        torch.manual_seed(self.seed + self.rank)
@@ -277,6 +291,101 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
                dist_model.parameters()):
            torch.testing.assert_close(dist_param, ref_param)
def test_checkpoint(self):
# Construct two models with same config and different params
num_layers = 5
layer_size = 2
torch.manual_seed(self.seed + self.rank)
_, _, model_save, optim_save = make_models(num_layers, layer_size)
_, _, model_load, optim_load = make_models(num_layers, layer_size)
# Train one of the models
num_steps = 3
micro_batch_steps = 2
batch_size = 4
for step in range(num_steps):
optim_save.zero_grad()
for micro_step in range(micro_batch_steps):
x = torch.rand(batch_size, layer_size) - 0.5
dy = torch.rand_like(x) - 0.5
x = x.cuda()
dy = dy.cuda()
y = model_save(x)
y.backward(dy)
optim_save.step()
# Make sure models are different
for param_save, param_load in zip(model_save.parameters(),
model_load.parameters()):
self.assertRaises(
AssertionError,
torch.testing.assert_close,
param_load, param_save,
)
# Save state on root rank and load on all ranks
state_dict = {
'model': model_save.state_dict(),
'optim': optim_save.state_dict(),
}
if self.rank == 0:
state_bytes = io.BytesIO()
torch.save(state_dict, state_bytes)
state_bytes = [state_bytes.getvalue()]
else:
state_bytes = [None]
torch.distributed.broadcast_object_list(state_bytes, src=0)
state_bytes = io.BytesIO(state_bytes[0])
state_dict = torch.load(state_bytes, map_location='cuda')
model_load.load_state_dict(state_dict['model'])
optim_load.load_state_dict(state_dict['optim'])
# Make sure models are identical
for param_save, param_load in zip(model_save.parameters(),
model_load.parameters()):
torch.testing.assert_close(param_load, param_save)
# Train both models
num_steps = 3
micro_batch_steps = 3
batch_size = 5
for step in range(num_steps):
# Reset gradients
optim_save.zero_grad()
optim_load.zero_grad()
# Forward and backward passes
for micro_step in range(micro_batch_steps):
# Synthetic data
x = torch.rand(batch_size, layer_size) - 0.5
dy = torch.rand_like(x) - 0.5
x = x.cuda()
dy = dy.cuda()
# Forward and backward pass
x_save = x.detach().clone().requires_grad_(True)
y_save = model_save(x_save)
y_save.backward(dy)
x_load = x.detach().clone().requires_grad_(True)
y_load = model_load(x_load)
y_load.backward(dy)
# Check that data tensors match
torch.testing.assert_close(y_load, y_save)
torch.testing.assert_close(x_load.grad, x_save.grad)
# Optimizer step
optim_save.step()
optim_load.step()
# Check that parameters match
for param_save, param_load in zip(model_save.parameters(),
model_load.parameters()):
torch.testing.assert_close(param_load, param_save)
if __name__ == "__main__": if __name__ == "__main__":
# Assume script has been run with torchrun # Assume script has been run with torchrun
common_utils.run_tests() common_utils.run_tests()